This repository was archived by the owner on Aug 7, 2024. It is now read-only.
[5/x] make FSDP2 with float8 all-gather work for Float8Linear #296
Closed
Changes from all commits (2 commits)
@@ -1,4 +1,5 @@
 import copy
+import itertools
 import threading
 import unittest
 from typing import Any, List

@@ -11,6 +12,7 @@
     Float8DynamicLinear,
     WeightWithDynamicFloat8CastTensor,
 )
+from float8_experimental.float8_linear import Float8Linear, TensorScalingType
 from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
 from test_fsdp2_common import (
     check_parity_bf16_mp,
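The newly added itertools import supports enumerating every combination of the two boolean test flags used below; a minimal illustration (plain Python, no repo dependencies):

import itertools

flag_combos = list(itertools.product([False, True], [False, True]))
# -> [(False, False), (False, True), (True, False), (True, True)]
# i.e. each test now runs with and without fp8 all-gather, for both
# Float8DynamicLinear and Float8Linear.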
@@ -74,8 +76,16 @@ def get_local_inp(self, dtype: torch.dtype = torch.float32):
         dist.broadcast(global_inp, src=0)
         return global_inp.view(self.world_size, -1)[self.rank].view(16, 16)

-    def swap_linear_with_dynamic(self, module: nn.Module, **kwargs: Any) -> nn.Module:
-        return swap_linear_with_float8_linear(module, Float8DynamicLinear, **kwargs)
+    def swap_linear_with_dynamic(
+        self, module: nn.Module, use_float8_linear=False, **kwargs: Any
+    ) -> nn.Module:
+        if use_float8_linear:
+            kwargs["scaling_type_x"] = TensorScalingType.DYNAMIC
+            kwargs["scaling_type_w"] = TensorScalingType.DYNAMIC
+            kwargs["scaling_type_dL_dY"] = TensorScalingType.DYNAMIC
+            return swap_linear_with_float8_linear(module, Float8Linear, **kwargs)
+        else:
+            return swap_linear_with_float8_linear(module, Float8DynamicLinear, **kwargs)


 class TestFloat8MultiProcess(FSDPTest, TestFloat8Common):

Review comment (on the new swap_linear_with_dynamic): Maybe losing some context, but is there a reason why the existing swap function doesn't work?
Reply: if the question is why do we need ...
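For reference, a hedged sketch of what the two branches of the reworked helper resolve to, using the same swap API and kwargs that the diff exercises; the toy nn.Sequential model is an assumption made only for this example:

import copy

import torch.nn as nn

from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
from float8_experimental.float8_linear import Float8Linear, TensorScalingType
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear

model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16))  # assumed toy model

# use_float8_linear=False: the pre-existing Float8DynamicLinear path
m_dynamic = swap_linear_with_float8_linear(
    copy.deepcopy(model), Float8DynamicLinear, emulate=True
)

# use_float8_linear=True: Float8Linear configured for dynamic scaling of x, w, and dL_dY
m_float8 = swap_linear_with_float8_linear(
    copy.deepcopy(model),
    Float8Linear,
    emulate=True,
    scaling_type_x=TensorScalingType.DYNAMIC,
    scaling_type_w=TensorScalingType.DYNAMIC,
    scaling_type_dL_dY=TensorScalingType.DYNAMIC,
)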
@@ -85,20 +95,26 @@ def world_size(self) -> int:

     @skip_if_lt_x_gpu(2)
     def test_transformer_parity_dynamic(self):
-        for enable_fsdp_fp8_all_gather in [False, True]:
-            self._test_transformer_parity_dynamic(enable_fsdp_fp8_all_gather)
+        for enable_fsdp_fp8_all_gather, use_float8_linear in itertools.product(
+            [False, True], [False, True]
+        ):
+            self._test_transformer_parity_dynamic(
+                enable_fsdp_fp8_all_gather, use_float8_linear
+            )

-    def _test_transformer_parity_dynamic(self, enable_fsdp_fp8_all_gather: bool):
+    def _test_transformer_parity_dynamic(
+        self, enable_fsdp_fp8_all_gather: bool, use_float8_linear: bool
+    ):
         # NOTE: Weight-tying does not compose with fp8 all-gather because the
         # embedding weight and output linear weight are tied but only the
         # latter uses fp8 compute. With fp8 all-gather, FSDP would pre-cast to
         # fp8 for that tied weight, incorrectly using fp8 for the embedding.
         weight_tying = not enable_fsdp_fp8_all_gather
         module = self.init_transformer(weight_tying=weight_tying)
         ref_module = copy.deepcopy(module)
-        ref_module = self.swap_linear_with_dynamic(ref_module).cuda()
+        ref_module = self.swap_linear_with_dynamic(ref_module, use_float8_linear).cuda()
         with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
-            module = self.swap_linear_with_dynamic(module)
+            module = self.swap_linear_with_dynamic(module, use_float8_linear)
         for submodule in module.modules():
             if isinstance(submodule, TransformerBlock):
                 fully_shard(submodule)
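For the weight-tying NOTE above, a minimal illustrative stand-in (not the test's actual Transformer) of the tied-parameter pattern that conflicts with fp8 all-gather:

import torch.nn as nn

class TiedTinyLM(nn.Module):  # hypothetical model, for illustration only
    def __init__(self, vocab_size: int = 32, dim: int = 16):
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab_size, dim)
        self.output = nn.Linear(dim, vocab_size, bias=False)
        # Weight tying: the embedding and the output projection share one parameter.
        self.output.weight = self.tok_embeddings.weight

# Only the output linear would be swapped to a float8 linear. With fp8 all-gather
# enabled, FSDP would communicate the shared weight as a float8 cast for the linear,
# so the embedding lookup would also see fp8, which is why the test disables weight
# tying whenever enable_fsdp_fp8_all_gather is True.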
@@ -108,17 +124,24 @@ def _test_transformer_parity_dynamic(self, enable_fsdp_fp8_all_gather: bool):
         local_inp = torch.randint(
             0, ref_module.tok_embeddings.weight.size(0), (16, 16), device="cuda"
         )
+        # TODO(future): change Float8DynamicLinear to module_cls below, and
+        # ensure there is no amax syncing for all-dynamic
         check_parity_no_mp(
             self, ref_module, ref_optim, module, optim, local_inp, Float8DynamicLinear
         )

     @skip_if_lt_x_gpu(2)
     def test_transformer_memory(self):
         """Tests peak active memory in the forward and backward passes."""
-        for enable_fsdp_fp8_all_gather in [False, True]:
-            self._test_transformer_memory(enable_fsdp_fp8_all_gather)
-
-    def _test_transformer_memory(self, enable_fsdp_fp8_all_gather: bool):
+        # for enable_fsdp_fp8_all_gather in [False, True]:
+        for enable_fsdp_fp8_all_gather, use_float8_linear in itertools.product(
+            [False, True], [False, True]
+        ):
+            self._test_transformer_memory(enable_fsdp_fp8_all_gather, use_float8_linear)
+
+    def _test_transformer_memory(
+        self, enable_fsdp_fp8_all_gather: bool, use_float8_linear: bool
+    ):
         torch.manual_seed(42)
         # Pre-run a linear forward (gemm and bias) and backward (gemm) to
         # allocate the cuBLAS workspaces before measuring the memory usage

Review comment (on the commented-out "# for enable_fsdp_fp8_all_gather in [False, True]:" line): can remove comment right?
@@ -141,7 +164,9 @@ def _test_transformer_memory(self, enable_fsdp_fp8_all_gather: bool):
         # Emulate the fp8 matmul to bypass the scaled matmul op's divisibility
         # requirement to use a smaller activation size
         with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
-            model = self.swap_linear_with_dynamic(model, emulate=True)
+            model = self.swap_linear_with_dynamic(
+                model, emulate=True, use_float8_linear=use_float8_linear
+            )
         model_unsharded_numel = sum(p.numel() for p in model.parameters())
         model_sharded_numel = (model_unsharded_numel + 1) // 2
         block_lin_weight_numel = 0
@@ -242,16 +267,23 @@ class TestFloat8MultiThread(FSDPTestMultiThread, TestFloat8Common):
     def world_size(self) -> int:
         return 2

-    @unittest.skipIf(not TEST_CUDA, "no cuda")
-    def test_weight_subclass_dynamic(self):
+    def _test_weight_subclass_dynamic(self, use_float8_linear):
+        float8_cls = Float8Linear if use_float8_linear else Float8DynamicLinear
+        extra_kwargs = {}
+        if use_float8_linear:
+            extra_kwargs["scaling_type_x"] = TensorScalingType.DYNAMIC
+            extra_kwargs["scaling_type_w"] = TensorScalingType.DYNAMIC
+            extra_kwargs["scaling_type_dL_dY"] = TensorScalingType.DYNAMIC
+            pass
         tensor_cls = WeightWithDynamicFloat8CastTensor
         # Check for a single FSDP paramter group
         module_fp32 = self.init_single_module()
         with set_enable_fsdp_fp8_all_gather(True):
             module = swap_linear_with_float8_linear(
                 module_fp32,
-                Float8DynamicLinear,
+                float8_cls,
                 emulate=True,
+                **extra_kwargs,
             )
         self.assertIsInstance(module.weight, tensor_cls)
         fully_shard(module)
@@ -265,8 +297,9 @@ def test_weight_subclass_dynamic(self):
         with set_enable_fsdp_fp8_all_gather(True):
             module = swap_linear_with_float8_linear(
                 module,
-                Float8DynamicLinear,
+                float8_cls,
                 emulate=True,
+                **extra_kwargs,
             )
         for param_name, param in module.named_parameters():
             if "weight" in param_name:
@@ -280,7 +313,14 @@ def test_weight_subclass_dynamic(self):
                 self.assertIsInstance(param.to_local(), tensor_cls)

     @unittest.skipIf(not TEST_CUDA, "no cuda")
-    def test_fp8_fp32_all_gather_dynamic_comm_size(self):
+    def test_weight_subclass_float8_dynamic_linear(self):
+        self._test_weight_subclass_dynamic(use_float8_linear=False)
+
+    @unittest.skipIf(not TEST_CUDA, "no cuda")
+    def test_weight_subclass_float8_linear(self):
+        self._test_weight_subclass_dynamic(use_float8_linear=True)
+
+    def _test_fp8_fp32_all_gather_dynamic_comm_size(self, use_float8_linear):
         """
         Tests that fp8 all-gather with dynamic scaling communicates the
         expected number of bytes.
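As a rough sketch of the byte accounting this kind of test checks (this is not the repo's get_expected_all_gather_size helper, and padding plus per-rank sharding details are ignored): a float8 all-gather moves 1 byte per weight element instead of 4 bytes for float32.

import torch.nn as nn

def approx_all_gather_bytes(module: nn.Module, fp8_all_gather: bool) -> int:
    # Sum over linear weights, the parameters eligible for the fp8 cast.
    numel = sum(m.weight.numel() for m in module.modules() if isinstance(m, nn.Linear))
    bytes_per_elem = 1 if fp8_all_gather else 4
    return numel * bytes_per_elem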
@@ -314,7 +354,7 @@ def get_expected_all_gather_size(module: nn.Module):
         module_fp32 = self.init_single_module()
         ref_module = copy.deepcopy(module_fp32)
         with set_enable_fsdp_fp8_all_gather(True):
-            module = self.swap_linear_with_dynamic(module_fp32)
+            module = self.swap_linear_with_dynamic(module_fp32, use_float8_linear)
         fully_shard(module)
         local_inp = self.get_local_inp()
         expected_all_gather_size = get_expected_all_gather_size(ref_module)
@@ -358,18 +398,30 @@ def get_expected_all_gather_size(module: nn.Module):
             [s for s in expected_all_gather_sizes for _ in range(self.world_size)],
         )

+    @unittest.skipIf(not TEST_CUDA, "no cuda")
+    def test_fp8_fp32_all_gather_float8_dynamic_linear_comm_size(self):
+        self._test_fp8_fp32_all_gather_dynamic_comm_size(use_float8_linear=False)
+
+    @unittest.skipIf(not TEST_CUDA, "no cuda")
+    def test_fp8_fp32_all_gather_float8_linear_comm_size(self):
+        self._test_fp8_fp32_all_gather_dynamic_comm_size(use_float8_linear=True)
+
     @unittest.skipIf(not TEST_CUDA, "no cuda")
     def test_fp32_fp8_single_module_parity(self):
         """
         Tests numeric parity for fp32 parameters with fp8 computation with a
         single module/FSDP communication group.
         """
-        for enable_fsdp_fp8_all_gather in [False, True]:
+        for enable_fsdp_fp8_all_gather, use_float8_linear in itertools.product(
+            [False, True], [False, True]
+        ):
             module_fp32 = self.init_single_module()
-            ref_module = self.swap_linear_with_dynamic(copy.deepcopy(module_fp32))
+            ref_module = self.swap_linear_with_dynamic(
+                copy.deepcopy(module_fp32), use_float8_linear
+            )
             ref_module = ref_module.cuda()
             with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
-                module = self.swap_linear_with_dynamic(module_fp32)
+                module = self.swap_linear_with_dynamic(module_fp32, use_float8_linear)
             fully_shard(module)
             ref_optim = torch.optim.Adam(ref_module.parameters(), lr=1e-2)
             optim = torch.optim.Adam(module.parameters(), lr=1e-2, foreach=True)
@@ -390,12 +442,16 @@ def test_fp32_fp8_multi_module_parity(self):
         Tests numeric parity for fp32 parameters with fp8 computation with
         multiple modules/FSDP communication groups.
         """
-        for enable_fsdp_fp8_all_gather in [False, True]:
+        for enable_fsdp_fp8_all_gather, use_float8_linear in itertools.product(
+            [False, True], [False, True]
+        ):
             module = self.init_multi_module()
             ref_module = copy.deepcopy(module)
-            ref_module = self.swap_linear_with_dynamic(ref_module).cuda()
+            ref_module = self.swap_linear_with_dynamic(
+                ref_module, use_float8_linear
+            ).cuda()
             with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
-                module = self.swap_linear_with_dynamic(module)
+                module = self.swap_linear_with_dynamic(module, use_float8_linear)
             for submodule in module:
                 fully_shard(submodule)
             fully_shard(module)
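The parity tests above delegate to check_parity_no_mp and check_parity_bf16_mp from test_fsdp2_common; a generic sketch of the shape of such a check, illustrative only and not the repo's helper:

import torch

def check_parity_sketch(ref_module, ref_optim, module, optim, inp, num_iters: int = 3):
    for _ in range(num_iters):
        ref_loss = ref_module(inp).sum()
        ref_loss.backward()
        ref_optim.step()
        ref_optim.zero_grad()

        loss = module(inp).sum()
        loss.backward()
        optim.step()
        optim.zero_grad()

        # The FSDP-sharded module should track the replicated reference numerically.
        torch.testing.assert_close(loss, ref_loss)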
Review comment: Nit: maybe a more helpful assert message.