This repository was archived by the owner on Aug 7, 2024. It is now read-only.

[5/x] make FSDP2 with float8 all-gather work for Float8Linear #296

Closed
wants to merge 2 commits
31 changes: 19 additions & 12 deletions float8_experimental/float8_linear.py
@@ -19,6 +19,7 @@
from float8_experimental.float8_dynamic_linear import (
cast_to_float8_e4m3_dynamic,
cast_to_float8_e5m2_dynamic_bw,
WeightWithDynamicFloat8CastTensor,
)

from float8_experimental.float8_tensor import (
@@ -163,6 +164,7 @@ def __init__(self, *args, **kwargs):
)
# Amax scales should always be kept as float32.
self.always_float32_buffers = set()
emulate = kwargs.pop("emulate", False)
scaling_type_x = kwargs.pop("scaling_type_x", TensorScalingType.DELAYED)
scaling_type_w = kwargs.pop("scaling_type_w", TensorScalingType.DELAYED)
scaling_type_dL_dY = kwargs.pop("scaling_type_dL_dY", TensorScalingType.DELAYED)
@@ -187,8 +189,12 @@ def __init__(self, *args, **kwargs):
self.create_buffers()

# Defines the behavior of the matmul in the forward and backward pass
self.forward_config = ScaledMMConfig()
self.backward_config = ScaledMMConfig()
self.forward_config = ScaledMMConfig(
emulate, True if not emulate else False, False, config.pad_inner_dim
)
self.backward_config = ScaledMMConfig(
emulate, False, False, config.pad_inner_dim
)

# Note: is_amax_initialized is not a buffer to avoid data dependent
# control flow visible to dynamo
@@ -428,19 +434,20 @@ def from_float(
scaling_type_x=scaling_type_x,
scaling_type_w=scaling_type_w,
scaling_type_dL_dY=scaling_type_dL_dY,
emulate=emulate,
)
if (
scaling_type_w == TensorScalingType.DYNAMIC
and config.enable_fsdp_fp8_all_gather
):
new_mod.weight = torch.nn.Parameter(
WeightWithDynamicFloat8CastTensor(mod.weight, new_mod.forward_config)
)
new_mod.weight = mod.weight
else:
assert not config.enable_fsdp_fp8_all_gather, "unsupported"
Contributor: Nit: maybe a more helpful assert message (a possible wording is sketched after this file's diff).

new_mod.weight = mod.weight
new_mod.bias = mod.bias
# need to create buffers again when moving from meta device to
# real device
new_mod.create_buffers()
# Defines the behavior of the matmul in the forward and backward
# Forward we use fast_accum, backwards we do not
# TODO(future PR): move below to the constructor
new_mod.forward_config = ScaledMMConfig(
emulate, True if not emulate else False, False, config.pad_inner_dim
)
new_mod.backward_config = ScaledMMConfig(
emulate, False, False, config.pad_inner_dim
)
return new_mod
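
On the nit above: one possible wording for the assert message, sketched under the assumption that the surrounding from_float scope is unchanged. This is a suggestion only, not code from the PR:

# Hypothetical wording only; `config` and `scaling_type_w` come from the
# surrounding from_float scope shown in the diff above.
assert not config.enable_fsdp_fp8_all_gather, (
    "config.enable_fsdp_fp8_all_gather requires scaling_type_w to be "
    f"TensorScalingType.DYNAMIC, but got scaling_type_w={scaling_type_w}"
)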
104 changes: 80 additions & 24 deletions test/test_fsdp2/test_fsdp2_eager.py
@@ -1,4 +1,5 @@
import copy
import itertools
import threading
import unittest
from typing import Any, List
@@ -11,6 +12,7 @@
Float8DynamicLinear,
WeightWithDynamicFloat8CastTensor,
)
from float8_experimental.float8_linear import Float8Linear, TensorScalingType
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
from test_fsdp2_common import (
check_parity_bf16_mp,
@@ -74,8 +76,16 @@ def get_local_inp(self, dtype: torch.dtype = torch.float32):
dist.broadcast(global_inp, src=0)
return global_inp.view(self.world_size, -1)[self.rank].view(16, 16)

def swap_linear_with_dynamic(self, module: nn.Module, **kwargs: Any) -> nn.Module:
return swap_linear_with_float8_linear(module, Float8DynamicLinear, **kwargs)
def swap_linear_with_dynamic(
Contributor: Maybe losing some context, but is there a reason why the existing swap function doesn't work?

Contributor Author: If the question is why we need swap_linear_with_dynamic at all, we probably don't. Removing it is not related to this PR, though, so I left it for a future person. (A direct-call sketch follows after this helper.)

self, module: nn.Module, use_float8_linear=False, **kwargs: Any
) -> nn.Module:
if use_float8_linear:
kwargs["scaling_type_x"] = TensorScalingType.DYNAMIC
kwargs["scaling_type_w"] = TensorScalingType.DYNAMIC
kwargs["scaling_type_dL_dY"] = TensorScalingType.DYNAMIC
return swap_linear_with_float8_linear(module, Float8Linear, **kwargs)
else:
return swap_linear_with_float8_linear(module, Float8DynamicLinear, **kwargs)
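
Regarding the review question above: if the swap_linear_with_dynamic wrapper were dropped, the tests could call the swap utility directly. A minimal sketch mirroring the Float8Linear branch of the helper above (not a change made in this PR):

# Hypothetical direct call; the kwargs mirror what swap_linear_with_dynamic
# sets when use_float8_linear=True.
module = swap_linear_with_float8_linear(
    module,
    Float8Linear,
    scaling_type_x=TensorScalingType.DYNAMIC,
    scaling_type_w=TensorScalingType.DYNAMIC,
    scaling_type_dL_dY=TensorScalingType.DYNAMIC,
)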


class TestFloat8MultiProcess(FSDPTest, TestFloat8Common):
@@ -85,20 +95,26 @@ def world_size(self) -> int:

@skip_if_lt_x_gpu(2)
def test_transformer_parity_dynamic(self):
for enable_fsdp_fp8_all_gather in [False, True]:
self._test_transformer_parity_dynamic(enable_fsdp_fp8_all_gather)
for enable_fsdp_fp8_all_gather, use_float8_linear in itertools.product(
[False, True], [False, True]
):
self._test_transformer_parity_dynamic(
enable_fsdp_fp8_all_gather, use_float8_linear
)

def _test_transformer_parity_dynamic(self, enable_fsdp_fp8_all_gather: bool):
def _test_transformer_parity_dynamic(
self, enable_fsdp_fp8_all_gather: bool, use_float8_linear: bool
):
# NOTE: Weight-tying does not compose with fp8 all-gather because the
# embedding weight and output linear weight are tied but only the
# latter uses fp8 compute. With fp8 all-gather, FSDP would pre-cast to
# fp8 for that tied weight, incorrectly using fp8 for the embedding.
weight_tying = not enable_fsdp_fp8_all_gather
module = self.init_transformer(weight_tying=weight_tying)
ref_module = copy.deepcopy(module)
ref_module = self.swap_linear_with_dynamic(ref_module).cuda()
ref_module = self.swap_linear_with_dynamic(ref_module, use_float8_linear).cuda()
with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
module = self.swap_linear_with_dynamic(module)
module = self.swap_linear_with_dynamic(module, use_float8_linear)
for submodule in module.modules():
if isinstance(submodule, TransformerBlock):
fully_shard(submodule)
@@ -108,17 +124,24 @@ def _test_transformer_parity_dynamic(self, enable_fsdp_fp8_all_gather: bool):
local_inp = torch.randint(
0, ref_module.tok_embeddings.weight.size(0), (16, 16), device="cuda"
)
# TODO(future): change Float8DynamicLinear to module_cls below, and
# ensure there is no amax syncing for all-dynamic
check_parity_no_mp(
self, ref_module, ref_optim, module, optim, local_inp, Float8DynamicLinear
)

@skip_if_lt_x_gpu(2)
def test_transformer_memory(self):
"""Tests peak active memory in the forward and backward passes."""
for enable_fsdp_fp8_all_gather in [False, True]:
self._test_transformer_memory(enable_fsdp_fp8_all_gather)

def _test_transformer_memory(self, enable_fsdp_fp8_all_gather: bool):
# for enable_fsdp_fp8_all_gather in [False, True]:
Contributor: can remove comment, right?

for enable_fsdp_fp8_all_gather, use_float8_linear in itertools.product(
[False, True], [False, True]
):
self._test_transformer_memory(enable_fsdp_fp8_all_gather, use_float8_linear)

def _test_transformer_memory(
self, enable_fsdp_fp8_all_gather: bool, use_float8_linear: bool
):
torch.manual_seed(42)
# Pre-run a linear forward (gemm and bias) and backward (gemm) to
# allocate the cuBLAS workspaces before measuring the memory usage
@@ -141,7 +164,9 @@ def _test_transformer_memory(self, enable_fsdp_fp8_all_gather: bool):
# Emulate the fp8 matmul to bypass the scaled matmul op's divisibility
# requirement to use a smaller activation size
with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
model = self.swap_linear_with_dynamic(model, emulate=True)
model = self.swap_linear_with_dynamic(
model, emulate=True, use_float8_linear=use_float8_linear
)
model_unsharded_numel = sum(p.numel() for p in model.parameters())
model_sharded_numel = (model_unsharded_numel + 1) // 2
block_lin_weight_numel = 0
@@ -242,16 +267,23 @@ class TestFloat8MultiThread(FSDPTestMultiThread, TestFloat8Common):
def world_size(self) -> int:
return 2

@unittest.skipIf(not TEST_CUDA, "no cuda")
def test_weight_subclass_dynamic(self):
def _test_weight_subclass_dynamic(self, use_float8_linear):
float8_cls = Float8Linear if use_float8_linear else Float8DynamicLinear
extra_kwargs = {}
if use_float8_linear:
extra_kwargs["scaling_type_x"] = TensorScalingType.DYNAMIC
extra_kwargs["scaling_type_w"] = TensorScalingType.DYNAMIC
extra_kwargs["scaling_type_dL_dY"] = TensorScalingType.DYNAMIC
tensor_cls = WeightWithDynamicFloat8CastTensor
# Check for a single FSDP paramter group
module_fp32 = self.init_single_module()
with set_enable_fsdp_fp8_all_gather(True):
module = swap_linear_with_float8_linear(
module_fp32,
Float8DynamicLinear,
float8_cls,
emulate=True,
**extra_kwargs,
)
self.assertIsInstance(module.weight, tensor_cls)
fully_shard(module)
@@ -265,8 +297,9 @@ def test_weight_subclass_dynamic(self):
with set_enable_fsdp_fp8_all_gather(True):
module = swap_linear_with_float8_linear(
module,
Float8DynamicLinear,
float8_cls,
emulate=True,
**extra_kwargs,
)
for param_name, param in module.named_parameters():
if "weight" in param_name:
@@ -280,7 +313,14 @@ self.assertIsInstance(param.to_local(), tensor_cls)
self.assertIsInstance(param.to_local(), tensor_cls)

@unittest.skipIf(not TEST_CUDA, "no cuda")
def test_fp8_fp32_all_gather_dynamic_comm_size(self):
def test_weight_subclass_float8_dynamic_linear(self):
self._test_weight_subclass_dynamic(use_float8_linear=False)

@unittest.skipIf(not TEST_CUDA, "no cuda")
def test_weight_subclass_float8_linear(self):
self._test_weight_subclass_dynamic(use_float8_linear=True)

def _test_fp8_fp32_all_gather_dynamic_comm_size(self, use_float8_linear):
"""
Tests that fp8 all-gather with dynamic scaling communicates the
expected number of bytes.
@@ -314,7 +354,7 @@ def get_expected_all_gather_size(module: nn.Module):
module_fp32 = self.init_single_module()
ref_module = copy.deepcopy(module_fp32)
with set_enable_fsdp_fp8_all_gather(True):
module = self.swap_linear_with_dynamic(module_fp32)
module = self.swap_linear_with_dynamic(module_fp32, use_float8_linear)
fully_shard(module)
local_inp = self.get_local_inp()
expected_all_gather_size = get_expected_all_gather_size(ref_module)
@@ -358,18 +398,30 @@ def get_expected_all_gather_size(module: nn.Module):
[s for s in expected_all_gather_sizes for _ in range(self.world_size)],
)

@unittest.skipIf(not TEST_CUDA, "no cuda")
def test_fp8_fp32_all_gather_float8_dynamic_linear_comm_size(self):
self._test_fp8_fp32_all_gather_dynamic_comm_size(use_float8_linear=False)

@unittest.skipIf(not TEST_CUDA, "no cuda")
def test_fp8_fp32_all_gather_float8_linear_comm_size(self):
self._test_fp8_fp32_all_gather_dynamic_comm_size(use_float8_linear=True)

@unittest.skipIf(not TEST_CUDA, "no cuda")
def test_fp32_fp8_single_module_parity(self):
"""
Tests numeric parity for fp32 parameters with fp8 computation with a
single module/FSDP communication group.
"""
for enable_fsdp_fp8_all_gather in [False, True]:
for enable_fsdp_fp8_all_gather, use_float8_linear in itertools.product(
[False, True], [False, True]
):
module_fp32 = self.init_single_module()
ref_module = self.swap_linear_with_dynamic(copy.deepcopy(module_fp32))
ref_module = self.swap_linear_with_dynamic(
copy.deepcopy(module_fp32), use_float8_linear
)
ref_module = ref_module.cuda()
with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
module = self.swap_linear_with_dynamic(module_fp32)
module = self.swap_linear_with_dynamic(module_fp32, use_float8_linear)
fully_shard(module)
ref_optim = torch.optim.Adam(ref_module.parameters(), lr=1e-2)
optim = torch.optim.Adam(module.parameters(), lr=1e-2, foreach=True)
@@ -390,12 +442,16 @@ def test_fp32_fp8_multi_module_parity(self):
Tests numeric parity for fp32 parameters with fp8 computation with
multiple modules/FSDP communication groups.
"""
for enable_fsdp_fp8_all_gather in [False, True]:
for enable_fsdp_fp8_all_gather, use_float8_linear in itertools.product(
[False, True], [False, True]
):
module = self.init_multi_module()
ref_module = copy.deepcopy(module)
ref_module = self.swap_linear_with_dynamic(ref_module).cuda()
ref_module = self.swap_linear_with_dynamic(
ref_module, use_float8_linear
).cuda()
with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
module = self.swap_linear_with_dynamic(module)
module = self.swap_linear_with_dynamic(module, use_float8_linear)
for submodule in module:
fully_shard(submodule)
fully_shard(module)