This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 29789ae

[5/x] make FSDP2 with float8 all-gather work for Float8Linear
Summary:

Adds test coverage for `Float8Linear` with all-dynamic scaling and FSDP2 with float8 all-gather. To make the tests pass, fixes a bug with initialization ordering in `Float8Linear.from_float`: we need to have the right forward config set before stashing it on the weight wrapper.

Test Plan:

```
python test/test_fsdp2/test_fsdp2_eager.py
/test/test_everything.sh
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: b6d6525
Pull Request resolved: #296
1 parent 1f0166b commit 29789ae
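For context, a minimal end-to-end sketch of the workflow this commit enables: swap `nn.Linear` modules for `Float8Linear` with all-dynamic scaling, turn on float8 all-gather, and shard with FSDP2. The swap call and scaling-type kwargs mirror the test code in the diffs below; the distributed setup and the direct `config.enable_fsdp_fp8_all_gather` assignment are assumptions (the tests use a `set_enable_fsdp_fp8_all_gather()` helper instead), so treat this as a sketch rather than the commit's own code.

```python
# Hedged sketch: assumes a multi-GPU environment launched via torchrun and the
# import paths used by the tests in this commit.
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard  # FSDP2

import float8_experimental.config as config
from float8_experimental.float8_linear import Float8Linear, TensorScalingType
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = nn.Sequential(nn.Linear(256, 256), nn.Linear(256, 256)).cuda()

# Enable float8 all-gather *before* swapping so that Float8Linear.from_float
# wraps each weight in WeightWithDynamicFloat8CastTensor (assumption: this flag
# is the module-level toggle behind the tests' set_enable_fsdp_fp8_all_gather()).
config.enable_fsdp_fp8_all_gather = True
model = swap_linear_with_float8_linear(
    model,
    Float8Linear,
    scaling_type_x=TensorScalingType.DYNAMIC,
    scaling_type_w=TensorScalingType.DYNAMIC,
    scaling_type_dL_dY=TensorScalingType.DYNAMIC,
    # emulate=True,  # as in the memory test, for GPUs without native float8 support
)
fully_shard(model)  # FSDP2: weights are all-gathered in float8, not bf16/fp32

out = model(torch.randn(16, 256, device="cuda")).sum()
out.backward()
```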

File tree

2 files changed: +113 -36 lines

float8_experimental/float8_linear.py
test/test_fsdp2/test_fsdp2_eager.py


float8_experimental/float8_linear.py

Lines changed: 19 additions & 12 deletions
```diff
@@ -19,6 +19,7 @@
 from float8_experimental.float8_dynamic_linear import (
     cast_to_float8_e4m3_dynamic,
     cast_to_float8_e5m2_dynamic_bw,
+    WeightWithDynamicFloat8CastTensor,
 )
 
 from float8_experimental.float8_tensor import (
@@ -163,6 +164,7 @@ def __init__(self, *args, **kwargs):
         )
         # Amax scales should always be kept as float32.
         self.always_float32_buffers = set()
+        emulate = kwargs.pop("emulate", False)
         scaling_type_x = kwargs.pop("scaling_type_x", TensorScalingType.DELAYED)
         scaling_type_w = kwargs.pop("scaling_type_w", TensorScalingType.DELAYED)
         scaling_type_dL_dY = kwargs.pop("scaling_type_dL_dY", TensorScalingType.DELAYED)
@@ -187,8 +189,12 @@ def __init__(self, *args, **kwargs):
         self.create_buffers()
 
         # Defines the behavior of the matmul in the forward and backward pass
-        self.forward_config = ScaledMMConfig()
-        self.backward_config = ScaledMMConfig()
+        self.forward_config = ScaledMMConfig(
+            emulate, True if not emulate else False, False, config.pad_inner_dim
+        )
+        self.backward_config = ScaledMMConfig(
+            emulate, False, False, config.pad_inner_dim
+        )
 
         # Note: is_amax_initialized is not a buffer to avoid data dependent
         # control flow visible to dynamo
@@ -428,19 +434,20 @@ def from_float(
             scaling_type_x=scaling_type_x,
             scaling_type_w=scaling_type_w,
             scaling_type_dL_dY=scaling_type_dL_dY,
+            emulate=emulate,
         )
-        new_mod.weight = mod.weight
+        if (
+            scaling_type_w == TensorScalingType.DYNAMIC
+            and config.enable_fsdp_fp8_all_gather
+        ):
+            new_mod.weight = torch.nn.Parameter(
+                WeightWithDynamicFloat8CastTensor(mod.weight, new_mod.forward_config)
+            )
+        else:
+            assert not config.enable_fsdp_fp8_all_gather, "unsupported"
+            new_mod.weight = mod.weight
         new_mod.bias = mod.bias
         # need to create buffers again when moving from meta device to
         # real device
        new_mod.create_buffers()
-        # Defines the behavior of the matmul in the forward and backward
-        # Forward we use fast_accum, backwards we do not
-        # TODO(future PR): move below to the constructor
-        new_mod.forward_config = ScaledMMConfig(
-            emulate, True if not emulate else False, False, config.pad_inner_dim
-        )
-        new_mod.backward_config = ScaledMMConfig(
-            emulate, False, False, config.pad_inner_dim
-        )
         return new_mod
```
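The key change above is that `forward_config` and `backward_config` are now built in `__init__` rather than at the end of `from_float`: `from_float` stashes `new_mod.forward_config` inside `WeightWithDynamicFloat8CastTensor`, so with the old ordering the wrapper would have captured a default-constructed `ScaledMMConfig`. For reference, a small sketch of the two configs, with argument meanings inferred from the deleted comment "Forward we use fast_accum, backwards we do not" (the exact `ScaledMMConfig` field names are not shown in this diff and are an assumption):

```python
from float8_experimental.float8_tensor import ScaledMMConfig

emulate = False        # True to emulate the fp8 matmul on hardware without native float8
pad_inner_dim = False  # stands in for float8_experimental.config.pad_inner_dim

# Forward matmul: fast accumulation when running real fp8 kernels
# ("True if not emulate else False" in the diff is just "not emulate").
forward_config = ScaledMMConfig(emulate, not emulate, False, pad_inner_dim)

# Backward matmuls: no fast accumulation.
backward_config = ScaledMMConfig(emulate, False, False, pad_inner_dim)
```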

test/test_fsdp2/test_fsdp2_eager.py

Lines changed: 94 additions & 24 deletions
```diff
@@ -1,4 +1,5 @@
 import copy
+import itertools
 import threading
 import unittest
 from typing import Any, List
@@ -11,6 +12,7 @@
     Float8DynamicLinear,
     WeightWithDynamicFloat8CastTensor,
 )
+from float8_experimental.float8_linear import Float8Linear, TensorScalingType
 from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
 from test_fsdp2_common import (
     check_parity_bf16_mp,
@@ -74,8 +76,30 @@ def get_local_inp(self, dtype: torch.dtype = torch.float32):
         dist.broadcast(global_inp, src=0)
         return global_inp.view(self.world_size, -1)[self.rank].view(16, 16)
 
-    def swap_linear_with_dynamic(self, module: nn.Module, **kwargs: Any) -> nn.Module:
-        return swap_linear_with_float8_linear(module, Float8DynamicLinear, **kwargs)
+    def swap_linear_with_dynamic(
+        self, module: nn.Module, use_float8_linear=False, **kwargs: Any
+    ) -> nn.Module:
+        skip_fqn_list = [
+            "output",
+        ]
+        for layer in range(3):
+            skip_fqn_list.append(f"layers.{layer}.attention.wq")
+            skip_fqn_list.append(f"layers.{layer}.attention.wk")
+            skip_fqn_list.append(f"layers.{layer}.attention.wv")
+            skip_fqn_list.append(f"layers.{layer}.attention.wo")
+            skip_fqn_list.append(f"layers.{layer}.feed_forward.w1")
+            # if layer > 0:
+            #     skip_fqn_list.append(f"layers.{layer}.feed_forward.w2")
+        # Note: with 3 layers, even a single linear leads to divergence
+        # with 1 layer, reproes for any layer
+        # kwargs["skip_fqn_list"] = skip_fqn_list
+        if use_float8_linear:
+            kwargs["scaling_type_x"] = TensorScalingType.DYNAMIC
+            kwargs["scaling_type_w"] = TensorScalingType.DYNAMIC
+            kwargs["scaling_type_dL_dY"] = TensorScalingType.DYNAMIC
+            return swap_linear_with_float8_linear(module, Float8Linear, **kwargs)
+        else:
+            return swap_linear_with_float8_linear(module, Float8DynamicLinear, **kwargs)
 
 
 class TestFloat8MultiProcess(FSDPTest, TestFloat8Common):
@@ -85,20 +109,26 @@ def world_size(self) -> int:
 
     @skip_if_lt_x_gpu(2)
     def test_transformer_parity_dynamic(self):
-        for enable_fsdp_fp8_all_gather in [False, True]:
-            self._test_transformer_parity_dynamic(enable_fsdp_fp8_all_gather)
+        for enable_fsdp_fp8_all_gather, use_float8_linear in itertools.product(
+            [False, True], [False, True]
+        ):
+            self._test_transformer_parity_dynamic(
+                enable_fsdp_fp8_all_gather, use_float8_linear
+            )
 
-    def _test_transformer_parity_dynamic(self, enable_fsdp_fp8_all_gather: bool):
+    def _test_transformer_parity_dynamic(
+        self, enable_fsdp_fp8_all_gather: bool, use_float8_linear: bool
+    ):
         # NOTE: Weight-tying does not compose with fp8 all-gather because the
         # embedding weight and output linear weight are tied but only the
         # latter uses fp8 compute. With fp8 all-gather, FSDP would pre-cast to
         # fp8 for that tied weight, incorrectly using fp8 for the embedding.
         weight_tying = not enable_fsdp_fp8_all_gather
         module = self.init_transformer(weight_tying=weight_tying)
         ref_module = copy.deepcopy(module)
-        ref_module = self.swap_linear_with_dynamic(ref_module).cuda()
+        ref_module = self.swap_linear_with_dynamic(ref_module, use_float8_linear).cuda()
         with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
-            module = self.swap_linear_with_dynamic(module)
+            module = self.swap_linear_with_dynamic(module, use_float8_linear)
             for submodule in module.modules():
                 if isinstance(submodule, TransformerBlock):
                     fully_shard(submodule)
@@ -108,17 +138,24 @@ def _test_transformer_parity_dynamic(self, enable_fsdp_fp8_all_gather: bool):
         local_inp = torch.randint(
             0, ref_module.tok_embeddings.weight.size(0), (16, 16), device="cuda"
         )
+        # TODO(future): change Float8DynamicLinear to module_cls below, and
+        # ensure there is no amax syncing for all-dynamic
         check_parity_no_mp(
             self, ref_module, ref_optim, module, optim, local_inp, Float8DynamicLinear
         )
 
     @skip_if_lt_x_gpu(2)
     def test_transformer_memory(self):
         """Tests peak active memory in the forward and backward passes."""
-        for enable_fsdp_fp8_all_gather in [False, True]:
-            self._test_transformer_memory(enable_fsdp_fp8_all_gather)
-
-    def _test_transformer_memory(self, enable_fsdp_fp8_all_gather: bool):
+        # for enable_fsdp_fp8_all_gather in [False, True]:
+        for enable_fsdp_fp8_all_gather, use_float8_linear in itertools.product(
+            [False, True], [False, True]
+        ):
+            self._test_transformer_memory(enable_fsdp_fp8_all_gather, use_float8_linear)
+
+    def _test_transformer_memory(
+        self, enable_fsdp_fp8_all_gather: bool, use_float8_linear: bool
+    ):
         torch.manual_seed(42)
         # Pre-run a linear forward (gemm and bias) and backward (gemm) to
         # allocate the cuBLAS workspaces before measuring the memory usage
@@ -141,7 +178,9 @@ def _test_transformer_memory(self, enable_fsdp_fp8_all_gather: bool):
         # Emulate the fp8 matmul to bypass the scaled matmul op's divisibility
         # requirement to use a smaller activation size
         with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
-            model = self.swap_linear_with_dynamic(model, emulate=True)
+            model = self.swap_linear_with_dynamic(
+                model, emulate=True, use_float8_linear=use_float8_linear
+            )
         model_unsharded_numel = sum(p.numel() for p in model.parameters())
         model_sharded_numel = (model_unsharded_numel + 1) // 2
         block_lin_weight_numel = 0
@@ -242,16 +281,23 @@ class TestFloat8MultiThread(FSDPTestMultiThread, TestFloat8Common):
     def world_size(self) -> int:
         return 2
 
-    @unittest.skipIf(not TEST_CUDA, "no cuda")
-    def test_weight_subclass_dynamic(self):
+    def _test_weight_subclass_dynamic(self, use_float8_linear):
+        float8_cls = Float8Linear if use_float8_linear else Float8DynamicLinear
+        extra_kwargs = {}
+        if use_float8_linear:
+            extra_kwargs["scaling_type_x"] = TensorScalingType.DYNAMIC
+            extra_kwargs["scaling_type_w"] = TensorScalingType.DYNAMIC
+            extra_kwargs["scaling_type_dL_dY"] = TensorScalingType.DYNAMIC
+            pass
         tensor_cls = WeightWithDynamicFloat8CastTensor
         # Check for a single FSDP paramter group
         module_fp32 = self.init_single_module()
         with set_enable_fsdp_fp8_all_gather(True):
             module = swap_linear_with_float8_linear(
                 module_fp32,
-                Float8DynamicLinear,
+                float8_cls,
                 emulate=True,
+                **extra_kwargs,
             )
         self.assertIsInstance(module.weight, tensor_cls)
         fully_shard(module)
@@ -265,8 +311,9 @@ def test_weight_subclass_dynamic(self):
         with set_enable_fsdp_fp8_all_gather(True):
             module = swap_linear_with_float8_linear(
                 module,
-                Float8DynamicLinear,
+                float8_cls,
                 emulate=True,
+                **extra_kwargs,
             )
         for param_name, param in module.named_parameters():
             if "weight" in param_name:
@@ -280,7 +327,14 @@ def test_weight_subclass_dynamic(self):
                 self.assertIsInstance(param.to_local(), tensor_cls)
 
     @unittest.skipIf(not TEST_CUDA, "no cuda")
-    def test_fp8_fp32_all_gather_dynamic_comm_size(self):
+    def test_weight_subclass_float8_dynamic_linear(self):
+        self._test_weight_subclass_dynamic(use_float8_linear=False)
+
+    @unittest.skipIf(not TEST_CUDA, "no cuda")
+    def test_weight_subclass_float8_linear(self):
+        self._test_weight_subclass_dynamic(use_float8_linear=True)
+
+    def _test_fp8_fp32_all_gather_dynamic_comm_size(self, use_float8_linear):
         """
         Tests that fp8 all-gather with dynamic scaling communicates the
         expected number of bytes.
@@ -314,7 +368,7 @@ def get_expected_all_gather_size(module: nn.Module):
         module_fp32 = self.init_single_module()
         ref_module = copy.deepcopy(module_fp32)
         with set_enable_fsdp_fp8_all_gather(True):
-            module = self.swap_linear_with_dynamic(module_fp32)
+            module = self.swap_linear_with_dynamic(module_fp32, use_float8_linear)
         fully_shard(module)
         local_inp = self.get_local_inp()
         expected_all_gather_size = get_expected_all_gather_size(ref_module)
@@ -358,18 +412,30 @@ def get_expected_all_gather_size(module: nn.Module):
             [s for s in expected_all_gather_sizes for _ in range(self.world_size)],
         )
 
+    @unittest.skipIf(not TEST_CUDA, "no cuda")
+    def test_fp8_fp32_all_gather_float8_dynamic_linear_comm_size(self):
+        self._test_fp8_fp32_all_gather_dynamic_comm_size(use_float8_linear=False)
+
+    @unittest.skipIf(not TEST_CUDA, "no cuda")
+    def test_fp8_fp32_all_gather_float8_linear_comm_size(self):
+        self._test_fp8_fp32_all_gather_dynamic_comm_size(use_float8_linear=True)
+
     @unittest.skipIf(not TEST_CUDA, "no cuda")
     def test_fp32_fp8_single_module_parity(self):
         """
         Tests numeric parity for fp32 parameters with fp8 computation with a
         single module/FSDP communication group.
         """
-        for enable_fsdp_fp8_all_gather in [False, True]:
+        for enable_fsdp_fp8_all_gather, use_float8_linear in itertools.product(
+            [False, True], [False, True]
+        ):
             module_fp32 = self.init_single_module()
-            ref_module = self.swap_linear_with_dynamic(copy.deepcopy(module_fp32))
+            ref_module = self.swap_linear_with_dynamic(
+                copy.deepcopy(module_fp32), use_float8_linear
+            )
             ref_module = ref_module.cuda()
             with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
-                module = self.swap_linear_with_dynamic(module_fp32)
+                module = self.swap_linear_with_dynamic(module_fp32, use_float8_linear)
             fully_shard(module)
             ref_optim = torch.optim.Adam(ref_module.parameters(), lr=1e-2)
             optim = torch.optim.Adam(module.parameters(), lr=1e-2, foreach=True)
@@ -390,12 +456,16 @@ def test_fp32_fp8_multi_module_parity(self):
         Tests numeric parity for fp32 parameters with fp8 computation with
         multiple modules/FSDP communication groups.
         """
-        for enable_fsdp_fp8_all_gather in [False, True]:
+        for enable_fsdp_fp8_all_gather, use_float8_linear in itertools.product(
+            [False, True], [False, True]
+        ):
             module = self.init_multi_module()
             ref_module = copy.deepcopy(module)
-            ref_module = self.swap_linear_with_dynamic(ref_module).cuda()
+            ref_module = self.swap_linear_with_dynamic(
+                ref_module, use_float8_linear
+            ).cuda()
             with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
-                module = self.swap_linear_with_dynamic(module)
+                module = self.swap_linear_with_dynamic(module, use_float8_linear)
             for submodule in module:
                 fully_shard(submodule)
             fully_shard(module)
```
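Outside the multi-GPU test harness, the core assertion from `_test_weight_subclass_dynamic` can be reproduced on a single CUDA process; a sketch, again assuming `float8_experimental.config.enable_fsdp_fp8_all_gather` is the module-level flag behind the tests' `set_enable_fsdp_fp8_all_gather()` helper.

```python
import torch.nn as nn

import float8_experimental.config as config
from float8_experimental.float8_dynamic_linear import WeightWithDynamicFloat8CastTensor
from float8_experimental.float8_linear import Float8Linear, TensorScalingType
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear

config.enable_fsdp_fp8_all_gather = True  # assumption: direct toggle of the config flag

module = swap_linear_with_float8_linear(
    nn.Linear(16, 16, device="cuda"),
    Float8Linear,
    emulate=True,  # as in the tests, so no float8-capable GPU is required
    scaling_type_x=TensorScalingType.DYNAMIC,
    scaling_type_w=TensorScalingType.DYNAMIC,
    scaling_type_dL_dY=TensorScalingType.DYNAMIC,
)

# With float8 all-gather enabled and dynamic weight scaling, from_float wraps
# the weight so that FSDP2 can all-gather it in float8.
assert isinstance(module.weight, WeightWithDynamicFloat8CastTensor)
```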
