This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 412222b

vkuzo authored and facebook-github-bot committed
make FSDP2 with float8 all-gather work for Float8Linear (#296)
Summary:
Pull Request resolved: #296

Adds test coverage for `Float8Linear` with all-dynamic scaling and FSDP2 with float8 all-gather. To make the tests pass, this fixes a bug with initialization ordering in `Float8Linear.from_float`: the right forward config needs to be set before it is stashed on the weight wrapper.

Reviewed By: drisspg

Differential Revision: D59305793

fbshipit-source-id: d1b207657f5e036801a0efb2d11e9f6ea547f148
1 parent 3ec9665 commit 412222b
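
The ordering bug described above can be illustrated with a minimal, self-contained sketch. This is not the library code: the *Sketch classes below are hypothetical stand-ins for ScaledMMConfig, WeightWithDynamicFloat8CastTensor, and Float8Linear. The point is that the weight wrapper captures the matmul config at wrap time, so the config must already be correct when from_float creates the wrapper.

# Minimal sketch of the ordering fix, with hypothetical stand-in classes.
from dataclasses import dataclass


@dataclass(frozen=True)
class ScaledMMConfigSketch:  # stand-in for ScaledMMConfig
    emulate: bool = False
    use_fast_accum: bool = False


class WeightWrapperSketch:  # stand-in for WeightWithDynamicFloat8CastTensor
    def __init__(self, weight, mm_config):
        self.weight = weight
        self.mm_config = mm_config  # captured once, at wrap time


class LinearSketch:  # stand-in for Float8Linear
    def __init__(self, emulate: bool = False):
        # After the fix, the real config is built in __init__ ...
        self.forward_config = ScaledMMConfigSketch(emulate, use_fast_accum=not emulate)

    @classmethod
    def from_float(cls, weight, emulate: bool = False):
        new_mod = cls(emulate=emulate)
        # ... so the wrapper created here sees the correct config. With the old
        # ordering (config assigned only after wrapping), the wrapper would
        # capture a default ScaledMMConfigSketch() instead.
        new_mod.weight = WeightWrapperSketch(weight, new_mod.forward_config)
        return new_mod


if __name__ == "__main__":
    mod = LinearSketch.from_float(weight=[1.0, 2.0], emulate=True)
    assert mod.weight.mm_config.emulate  # fails under the old ordering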

File tree

2 files changed: +99 −36 lines changed

float8_experimental/float8_linear.py

Lines changed: 19 additions & 12 deletions
@@ -19,6 +19,7 @@
 from float8_experimental.float8_dynamic_linear import (
     cast_to_float8_e4m3_dynamic,
     cast_to_float8_e5m2_dynamic_bw,
+    WeightWithDynamicFloat8CastTensor,
 )

 from float8_experimental.float8_tensor import (
@@ -163,6 +164,7 @@ def __init__(self, *args, **kwargs):
         )
         # Amax scales should always be kept as float32.
         self.always_float32_buffers = set()
+        emulate = kwargs.pop("emulate", False)
         scaling_type_x = kwargs.pop("scaling_type_x", TensorScalingType.DELAYED)
         scaling_type_w = kwargs.pop("scaling_type_w", TensorScalingType.DELAYED)
         scaling_type_dL_dY = kwargs.pop("scaling_type_dL_dY", TensorScalingType.DELAYED)
@@ -187,8 +189,12 @@ def __init__(self, *args, **kwargs):
         self.create_buffers()

         # Defines the behavior of the matmul in the forward and backward pass
-        self.forward_config = ScaledMMConfig()
-        self.backward_config = ScaledMMConfig()
+        self.forward_config = ScaledMMConfig(
+            emulate, True if not emulate else False, False, config.pad_inner_dim
+        )
+        self.backward_config = ScaledMMConfig(
+            emulate, False, False, config.pad_inner_dim
+        )

         # Note: is_amax_initialized is not a buffer to avoid data dependent
         # control flow visible to dynamo
@@ -428,19 +434,20 @@ def from_float(
             scaling_type_x=scaling_type_x,
             scaling_type_w=scaling_type_w,
             scaling_type_dL_dY=scaling_type_dL_dY,
+            emulate=emulate,
+        )
+        if (
+            scaling_type_w == TensorScalingType.DYNAMIC
+            and config.enable_fsdp_fp8_all_gather
+        ):
+            new_mod.weight = torch.nn.Parameter(
+                WeightWithDynamicFloat8CastTensor(mod.weight, new_mod.forward_config)
             )
-        new_mod.weight = mod.weight
+        else:
+            assert not config.enable_fsdp_fp8_all_gather, "unsupported"
+            new_mod.weight = mod.weight
         new_mod.bias = mod.bias
         # need to create buffers again when moving from meta device to
         # real device
         new_mod.create_buffers()
-        # Defines the behavior of the matmul in the forward and backward
-        # Forward we use fast_accum, backwards we do not
-        # TODO(future PR): move below to the constructor
-        new_mod.forward_config = ScaledMMConfig(
-            emulate, True if not emulate else False, False, config.pad_inner_dim
-        )
-        new_mod.backward_config = ScaledMMConfig(
-            emulate, False, False, config.pad_inner_dim
-        )
         return new_mod
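
For reference, the path covered by the new tests (below) looks roughly like the following usage sketch. The float8_experimental imports mirror the diff above; the fully_shard import path is an assumption (FSDP2's composable API), and in the tests the swap additionally runs under with set_enable_fsdp_fp8_all_gather(True), whose import is not shown in this diff.

# Hedged usage sketch, not a verified recipe: swap nn.Linear modules for
# Float8Linear with all-dynamic scaling, then shard with FSDP2.
import torch.nn as nn

from float8_experimental.float8_linear import Float8Linear, TensorScalingType
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
from torch.distributed._composable.fsdp import fully_shard  # assumed FSDP2 import path


def to_fsdp2_float8(module: nn.Module) -> nn.Module:
    # All-dynamic scaling is the configuration exercised by the new tests.
    # In the tests this swap happens inside `with set_enable_fsdp_fp8_all_gather(True):`
    # so that the weight is wrapped in WeightWithDynamicFloat8CastTensor.
    module = swap_linear_with_float8_linear(
        module,
        Float8Linear,
        scaling_type_x=TensorScalingType.DYNAMIC,
        scaling_type_w=TensorScalingType.DYNAMIC,
        scaling_type_dL_dY=TensorScalingType.DYNAMIC,
    )
    fully_shard(module)  # weights then all-gather in float8 during forward/backward
    return module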

test/test_fsdp2/test_fsdp2_eager.py

Lines changed: 80 additions & 24 deletions
@@ -1,4 +1,5 @@
 import copy
+import itertools
 import threading
 import unittest
 from typing import Any, List
@@ -11,6 +12,7 @@
     Float8DynamicLinear,
     WeightWithDynamicFloat8CastTensor,
 )
+from float8_experimental.float8_linear import Float8Linear, TensorScalingType
 from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
 from test_fsdp2_common import (
     check_parity_bf16_mp,
@@ -74,8 +76,16 @@ def get_local_inp(self, dtype: torch.dtype = torch.float32):
         dist.broadcast(global_inp, src=0)
         return global_inp.view(self.world_size, -1)[self.rank].view(16, 16)

-    def swap_linear_with_dynamic(self, module: nn.Module, **kwargs: Any) -> nn.Module:
-        return swap_linear_with_float8_linear(module, Float8DynamicLinear, **kwargs)
+    def swap_linear_with_dynamic(
+        self, module: nn.Module, use_float8_linear=False, **kwargs: Any
+    ) -> nn.Module:
+        if use_float8_linear:
+            kwargs["scaling_type_x"] = TensorScalingType.DYNAMIC
+            kwargs["scaling_type_w"] = TensorScalingType.DYNAMIC
+            kwargs["scaling_type_dL_dY"] = TensorScalingType.DYNAMIC
+            return swap_linear_with_float8_linear(module, Float8Linear, **kwargs)
+        else:
+            return swap_linear_with_float8_linear(module, Float8DynamicLinear, **kwargs)


 class TestFloat8MultiProcess(FSDPTest, TestFloat8Common):
@@ -85,20 +95,26 @@ def world_size(self) -> int:

     @skip_if_lt_x_gpu(2)
     def test_transformer_parity_dynamic(self):
-        for enable_fsdp_fp8_all_gather in [False, True]:
-            self._test_transformer_parity_dynamic(enable_fsdp_fp8_all_gather)
+        for enable_fsdp_fp8_all_gather, use_float8_linear in itertools.product(
+            [False, True], [False, True]
+        ):
+            self._test_transformer_parity_dynamic(
+                enable_fsdp_fp8_all_gather, use_float8_linear
+            )

-    def _test_transformer_parity_dynamic(self, enable_fsdp_fp8_all_gather: bool):
+    def _test_transformer_parity_dynamic(
+        self, enable_fsdp_fp8_all_gather: bool, use_float8_linear: bool
+    ):
         # NOTE: Weight-tying does not compose with fp8 all-gather because the
         # embedding weight and output linear weight are tied but only the
         # latter uses fp8 compute. With fp8 all-gather, FSDP would pre-cast to
         # fp8 for that tied weight, incorrectly using fp8 for the embedding.
         weight_tying = not enable_fsdp_fp8_all_gather
         module = self.init_transformer(weight_tying=weight_tying)
         ref_module = copy.deepcopy(module)
-        ref_module = self.swap_linear_with_dynamic(ref_module).cuda()
+        ref_module = self.swap_linear_with_dynamic(ref_module, use_float8_linear).cuda()
         with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
-            module = self.swap_linear_with_dynamic(module)
+            module = self.swap_linear_with_dynamic(module, use_float8_linear)
             for submodule in module.modules():
                 if isinstance(submodule, TransformerBlock):
                     fully_shard(submodule)
@@ -108,17 +124,24 @@ def _test_transformer_parity_dynamic(self, enable_fsdp_fp8_all_gather: bool):
         local_inp = torch.randint(
             0, ref_module.tok_embeddings.weight.size(0), (16, 16), device="cuda"
         )
+        # TODO(future): change Float8DynamicLinear to module_cls below, and
+        # ensure there is no amax syncing for all-dynamic
         check_parity_no_mp(
             self, ref_module, ref_optim, module, optim, local_inp, Float8DynamicLinear
         )

     @skip_if_lt_x_gpu(2)
     def test_transformer_memory(self):
         """Tests peak active memory in the forward and backward passes."""
-        for enable_fsdp_fp8_all_gather in [False, True]:
-            self._test_transformer_memory(enable_fsdp_fp8_all_gather)
-
-    def _test_transformer_memory(self, enable_fsdp_fp8_all_gather: bool):
+        # for enable_fsdp_fp8_all_gather in [False, True]:
+        for enable_fsdp_fp8_all_gather, use_float8_linear in itertools.product(
+            [False, True], [False, True]
+        ):
+            self._test_transformer_memory(enable_fsdp_fp8_all_gather, use_float8_linear)
+
+    def _test_transformer_memory(
+        self, enable_fsdp_fp8_all_gather: bool, use_float8_linear: bool
+    ):
         torch.manual_seed(42)
         # Pre-run a linear forward (gemm and bias) and backward (gemm) to
@@ -141,7 +164,9 @@ def _test_transformer_memory(self, enable_fsdp_fp8_all_gather: bool):
         # Emulate the fp8 matmul to bypass the scaled matmul op's divisibility
         # requirement to use a smaller activation size
         with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
-            model = self.swap_linear_with_dynamic(model, emulate=True)
+            model = self.swap_linear_with_dynamic(
+                model, emulate=True, use_float8_linear=use_float8_linear
+            )
         model_unsharded_numel = sum(p.numel() for p in model.parameters())
         model_sharded_numel = (model_unsharded_numel + 1) // 2
         block_lin_weight_numel = 0
@@ -242,16 +267,23 @@ class TestFloat8MultiThread(FSDPTestMultiThread, TestFloat8Common):
     def world_size(self) -> int:
         return 2

-    @unittest.skipIf(not TEST_CUDA, "no cuda")
-    def test_weight_subclass_dynamic(self):
+    def _test_weight_subclass_dynamic(self, use_float8_linear):
+        float8_cls = Float8Linear if use_float8_linear else Float8DynamicLinear
+        extra_kwargs = {}
+        if use_float8_linear:
+            extra_kwargs["scaling_type_x"] = TensorScalingType.DYNAMIC
+            extra_kwargs["scaling_type_w"] = TensorScalingType.DYNAMIC
+            extra_kwargs["scaling_type_dL_dY"] = TensorScalingType.DYNAMIC
+            pass
         tensor_cls = WeightWithDynamicFloat8CastTensor
         # Check for a single FSDP paramter group
         module_fp32 = self.init_single_module()
         with set_enable_fsdp_fp8_all_gather(True):
             module = swap_linear_with_float8_linear(
                 module_fp32,
-                Float8DynamicLinear,
+                float8_cls,
                 emulate=True,
+                **extra_kwargs,
             )
         self.assertIsInstance(module.weight, tensor_cls)
         fully_shard(module)
@@ -265,8 +297,9 @@ def test_weight_subclass_dynamic(self):
         with set_enable_fsdp_fp8_all_gather(True):
             module = swap_linear_with_float8_linear(
                 module,
-                Float8DynamicLinear,
+                float8_cls,
                 emulate=True,
+                **extra_kwargs,
             )
         for param_name, param in module.named_parameters():
             if "weight" in param_name:
@@ -280,7 +313,14 @@ def test_weight_subclass_dynamic(self):
             self.assertIsInstance(param.to_local(), tensor_cls)

     @unittest.skipIf(not TEST_CUDA, "no cuda")
-    def test_fp8_fp32_all_gather_dynamic_comm_size(self):
+    def test_weight_subclass_float8_dynamic_linear(self):
+        self._test_weight_subclass_dynamic(use_float8_linear=False)
+
+    @unittest.skipIf(not TEST_CUDA, "no cuda")
+    def test_weight_subclass_float8_linear(self):
+        self._test_weight_subclass_dynamic(use_float8_linear=True)
+
+    def _test_fp8_fp32_all_gather_dynamic_comm_size(self, use_float8_linear):
         """
         Tests that fp8 all-gather with dynamic scaling communicates the
         expected number of bytes.
@@ -314,7 +354,7 @@ def get_expected_all_gather_size(module: nn.Module):
         module_fp32 = self.init_single_module()
         ref_module = copy.deepcopy(module_fp32)
         with set_enable_fsdp_fp8_all_gather(True):
-            module = self.swap_linear_with_dynamic(module_fp32)
+            module = self.swap_linear_with_dynamic(module_fp32, use_float8_linear)
         fully_shard(module)
         local_inp = self.get_local_inp()
         expected_all_gather_size = get_expected_all_gather_size(ref_module)
@@ -358,18 +398,30 @@ def get_expected_all_gather_size(module: nn.Module):
             [s for s in expected_all_gather_sizes for _ in range(self.world_size)],
         )

+    @unittest.skipIf(not TEST_CUDA, "no cuda")
+    def test_fp8_fp32_all_gather_float8_dynamic_linear_comm_size(self):
+        self._test_fp8_fp32_all_gather_dynamic_comm_size(use_float8_linear=False)
+
+    @unittest.skipIf(not TEST_CUDA, "no cuda")
+    def test_fp8_fp32_all_gather_float8_linear_comm_size(self):
+        self._test_fp8_fp32_all_gather_dynamic_comm_size(use_float8_linear=True)
+
     @unittest.skipIf(not TEST_CUDA, "no cuda")
     def test_fp32_fp8_single_module_parity(self):
         """
         Tests numeric parity for fp32 parameters with fp8 computation with a
         single module/FSDP communication group.
         """
-        for enable_fsdp_fp8_all_gather in [False, True]:
+        for enable_fsdp_fp8_all_gather, use_float8_linear in itertools.product(
+            [False, True], [False, True]
+        ):
             module_fp32 = self.init_single_module()
-            ref_module = self.swap_linear_with_dynamic(copy.deepcopy(module_fp32))
+            ref_module = self.swap_linear_with_dynamic(
+                copy.deepcopy(module_fp32), use_float8_linear
+            )
             ref_module = ref_module.cuda()
             with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
-                module = self.swap_linear_with_dynamic(module_fp32)
+                module = self.swap_linear_with_dynamic(module_fp32, use_float8_linear)
             fully_shard(module)
             ref_optim = torch.optim.Adam(ref_module.parameters(), lr=1e-2)
             optim = torch.optim.Adam(module.parameters(), lr=1e-2, foreach=True)
@@ -390,12 +442,16 @@ def test_fp32_fp8_multi_module_parity(self):
         Tests numeric parity for fp32 parameters with fp8 computation with
         multiple modules/FSDP communication groups.
         """
-        for enable_fsdp_fp8_all_gather in [False, True]:
+        for enable_fsdp_fp8_all_gather, use_float8_linear in itertools.product(
+            [False, True], [False, True]
+        ):
             module = self.init_multi_module()
             ref_module = copy.deepcopy(module)
-            ref_module = self.swap_linear_with_dynamic(ref_module).cuda()
+            ref_module = self.swap_linear_with_dynamic(
+                ref_module, use_float8_linear
+            ).cuda()
             with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
-                module = self.swap_linear_with_dynamic(module)
+                module = self.swap_linear_with_dynamic(module, use_float8_linear)
             for submodule in module:
                 fully_shard(submodule)
             fully_shard(module)
