This repository was archived by the owner on Aug 7, 2024. It is now read-only.

delete swap_linear_with_dynamic from fsdp2 eager test case #311

Closed · wants to merge 2 commits
30 changes: 13 additions & 17 deletions test/test_fsdp2/test_fsdp2_eager.py
@@ -72,12 +72,6 @@ def get_local_inp(self, dtype: torch.dtype = torch.float32):
         dist.broadcast(global_inp, src=0)
         return global_inp.view(self.world_size, -1)[self.rank].view(16, 16)
 
-    def swap_linear_with_dynamic(self, module: nn.Module, **kwargs: Any) -> nn.Module:
-        kwargs["scaling_type_x"] = TensorScalingType.DYNAMIC
-        kwargs["scaling_type_w"] = TensorScalingType.DYNAMIC
-        kwargs["scaling_type_dL_dY"] = TensorScalingType.DYNAMIC
-        return swap_linear_with_float8_linear(module, **kwargs)
-
 
 class TestFloat8MultiProcess(FSDPTest, TestFloat8Common):
     @property
@@ -106,11 +100,11 @@ def _test_transformer_parity_dynamic(
         # latter uses fp8 compute. With fp8 all-gather, FSDP would pre-cast to
         # fp8 for that tied weight, incorrectly using fp8 for the embedding.
         weight_tying = not enable_fsdp_fp8_all_gather
-        module = self.init_transformer(weight_tying=weight_tying)
+        module = self.init_transformer(weight_tying=weight_tying).cuda()
         ref_module = copy.deepcopy(module)
-        ref_module = self.swap_linear_with_dynamic(ref_module).cuda()
+        swap_linear_with_float8_linear(ref_module)
         with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
-            module = self.swap_linear_with_dynamic(module)
+            swap_linear_with_float8_linear(module)
             for submodule in module.modules():
                 if isinstance(submodule, TransformerBlock):
                     fully_shard(submodule)
@@ -153,7 +147,7 @@ def _test_transformer_memory(self, enable_fsdp_fp8_all_gather: bool):
         # Emulate the fp8 matmul to bypass the scaled matmul op's divisibility
         # requirement to use a smaller activation size
         with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
-            model = self.swap_linear_with_dynamic(model, emulate=True)
+            swap_linear_with_float8_linear(model, emulate=True)
         model_unsharded_numel = sum(p.numel() for p in model.parameters())
         model_sharded_numel = (model_unsharded_numel + 1) // 2
         block_lin_weight_numel = 0
@@ -331,7 +325,8 @@ def get_expected_all_gather_size(module: nn.Module):
         module_fp32 = self.init_single_module()
         ref_module = copy.deepcopy(module_fp32)
         with set_enable_fsdp_fp8_all_gather(True):
-            module = self.swap_linear_with_dynamic(module_fp32)
+            module_fp32 = swap_linear_with_float8_linear(module_fp32)
+            module = module_fp32
             fully_shard(module)
         local_inp = self.get_local_inp()
         expected_all_gather_size = get_expected_all_gather_size(ref_module)
@@ -359,7 +354,7 @@ def get_expected_all_gather_size(module: nn.Module):
         module = self.init_multi_module()
         ref_module = copy.deepcopy(module)
         with set_enable_fsdp_fp8_all_gather(True):
-            module = self.swap_linear_with_dynamic(module)
+            module = swap_linear_with_float8_linear(module)
             for submodule in module:
                 fully_shard(submodule)
             fully_shard(module)
@@ -383,10 +378,11 @@ def test_fp32_fp8_single_module_parity(self):
         """
         for enable_fsdp_fp8_all_gather in [False, True]:
             module_fp32 = self.init_single_module()
-            ref_module = self.swap_linear_with_dynamic(copy.deepcopy(module_fp32))
+            ref_module = copy.deepcopy(module_fp32)
+            ref_module = swap_linear_with_float8_linear(ref_module)
+            ref_module = ref_module.cuda()
             with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
-                module = self.swap_linear_with_dynamic(module_fp32)
+                module = swap_linear_with_float8_linear(module_fp32)
                 fully_shard(module)
             ref_optim = torch.optim.Adam(ref_module.parameters(), lr=1e-2)
             optim = torch.optim.Adam(module.parameters(), lr=1e-2, foreach=True)
@@ -407,11 +403,11 @@ def test_fp32_fp8_multi_module_parity(self):
         multiple modules/FSDP communication groups.
         """
         for enable_fsdp_fp8_all_gather in [False, True]:
-            module = self.init_multi_module()
+            module = self.init_multi_module().cuda()
             ref_module = copy.deepcopy(module)
-            ref_module = self.swap_linear_with_dynamic(ref_module).cuda()
+            ref_module = swap_linear_with_float8_linear(ref_module)
             with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
-                module = self.swap_linear_with_dynamic(module)
+                module = swap_linear_with_float8_linear(module)
                 for submodule in module:
                     fully_shard(submodule)
                 fully_shard(module)
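
For reference, a minimal sketch of the pattern this PR applies, assuming (as the deleted helper suggests) that swap_linear_with_float8_linear already defaults scaling_type_x, scaling_type_w, and scaling_type_dL_dY to TensorScalingType.DYNAMIC, which is what would make the test-local swap_linear_with_dynamic wrapper redundant. The toy nn.Sequential model is illustrative only, and the import path is assumed from the library layout rather than shown in this diff.

import copy

import torch.nn as nn
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear

# Toy stand-in for the test's init_* helpers (illustrative only).
model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16))
ref_model = copy.deepcopy(model)

# Old pattern (deleted): ref_model = self.swap_linear_with_dynamic(ref_model)
# New pattern: call the library swap directly; it mutates the module in place
# and also returns it, matching both usages in the diff above.
swap_linear_with_float8_linear(ref_model)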