@@ -1,4 +1,5 @@
 import copy
+import itertools
 import threading
 import unittest
 from typing import Any, List
@@ -11,6 +12,7 @@
     Float8DynamicLinear,
     WeightWithDynamicFloat8CastTensor,
 )
+from float8_experimental.float8_linear import Float8Linear, TensorScalingType
 from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
 from test_fsdp2_common import (
     check_parity_bf16_mp,
@@ -74,8 +76,30 @@ def get_local_inp(self, dtype: torch.dtype = torch.float32):
         dist.broadcast(global_inp, src=0)
         return global_inp.view(self.world_size, -1)[self.rank].view(16, 16)

-    def swap_linear_with_dynamic(self, module: nn.Module, **kwargs: Any) -> nn.Module:
-        return swap_linear_with_float8_linear(module, Float8DynamicLinear, **kwargs)
+    def swap_linear_with_dynamic(
+        self, module: nn.Module, use_float8_linear=False, **kwargs: Any
+    ) -> nn.Module:
+        skip_fqn_list = [
+            "output",
+        ]
+        for layer in range(3):
+            skip_fqn_list.append(f"layers.{layer}.attention.wq")
+            skip_fqn_list.append(f"layers.{layer}.attention.wk")
+            skip_fqn_list.append(f"layers.{layer}.attention.wv")
+            skip_fqn_list.append(f"layers.{layer}.attention.wo")
+            skip_fqn_list.append(f"layers.{layer}.feed_forward.w1")
+            # if layer > 0:
+            #     skip_fqn_list.append(f"layers.{layer}.feed_forward.w2")
+        # Note: with 3 layers, even a single linear leads to divergence
+        # with 1 layer, reproes for any layer
+        # kwargs["skip_fqn_list"] = skip_fqn_list
+        if use_float8_linear:
+            kwargs["scaling_type_x"] = TensorScalingType.DYNAMIC
+            kwargs["scaling_type_w"] = TensorScalingType.DYNAMIC
+            kwargs["scaling_type_dL_dY"] = TensorScalingType.DYNAMIC
+            return swap_linear_with_float8_linear(module, Float8Linear, **kwargs)
+        else:
+            return swap_linear_with_float8_linear(module, Float8DynamicLinear, **kwargs)


 class TestFloat8MultiProcess(FSDPTest, TestFloat8Common):
@@ -85,20 +109,26 @@ def world_size(self) -> int:

     @skip_if_lt_x_gpu(2)
     def test_transformer_parity_dynamic(self):
-        for enable_fsdp_fp8_all_gather in [False, True]:
-            self._test_transformer_parity_dynamic(enable_fsdp_fp8_all_gather)
+        for enable_fsdp_fp8_all_gather, use_float8_linear in itertools.product(
+            [False, True], [False, True]
+        ):
+            self._test_transformer_parity_dynamic(
+                enable_fsdp_fp8_all_gather, use_float8_linear
+            )

-    def _test_transformer_parity_dynamic(self, enable_fsdp_fp8_all_gather: bool):
+    def _test_transformer_parity_dynamic(
+        self, enable_fsdp_fp8_all_gather: bool, use_float8_linear: bool
+    ):
         # NOTE: Weight-tying does not compose with fp8 all-gather because the
         # embedding weight and output linear weight are tied but only the
         # latter uses fp8 compute. With fp8 all-gather, FSDP would pre-cast to
         # fp8 for that tied weight, incorrectly using fp8 for the embedding.
         weight_tying = not enable_fsdp_fp8_all_gather
         module = self.init_transformer(weight_tying=weight_tying)
         ref_module = copy.deepcopy(module)
-        ref_module = self.swap_linear_with_dynamic(ref_module).cuda()
+        ref_module = self.swap_linear_with_dynamic(ref_module, use_float8_linear).cuda()
         with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
-            module = self.swap_linear_with_dynamic(module)
+            module = self.swap_linear_with_dynamic(module, use_float8_linear)
         for submodule in module.modules():
             if isinstance(submodule, TransformerBlock):
                 fully_shard(submodule)
@@ -108,17 +138,24 @@ def _test_transformer_parity_dynamic(self, enable_fsdp_fp8_all_gather: bool):
         local_inp = torch.randint(
             0, ref_module.tok_embeddings.weight.size(0), (16, 16), device="cuda"
         )
+        # TODO(future): change Float8DynamicLinear to module_cls below, and
+        # ensure there is no amax syncing for all-dynamic
         check_parity_no_mp(
             self, ref_module, ref_optim, module, optim, local_inp, Float8DynamicLinear
         )

     @skip_if_lt_x_gpu(2)
     def test_transformer_memory(self):
         """Tests peak active memory in the forward and backward passes."""
-        for enable_fsdp_fp8_all_gather in [False, True]:
-            self._test_transformer_memory(enable_fsdp_fp8_all_gather)
-
-    def _test_transformer_memory(self, enable_fsdp_fp8_all_gather: bool):
+        # for enable_fsdp_fp8_all_gather in [False, True]:
+        for enable_fsdp_fp8_all_gather, use_float8_linear in itertools.product(
+            [False, True], [False, True]
+        ):
+            self._test_transformer_memory(enable_fsdp_fp8_all_gather, use_float8_linear)
+
+    def _test_transformer_memory(
+        self, enable_fsdp_fp8_all_gather: bool, use_float8_linear: bool
+    ):
         torch.manual_seed(42)
         # Pre-run a linear forward (gemm and bias) and backward (gemm) to
         # allocate the cuBLAS workspaces before measuring the memory usage
@@ -141,7 +178,9 @@ def _test_transformer_memory(self, enable_fsdp_fp8_all_gather: bool):
         # Emulate the fp8 matmul to bypass the scaled matmul op's divisibility
         # requirement to use a smaller activation size
         with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
-            model = self.swap_linear_with_dynamic(model, emulate=True)
+            model = self.swap_linear_with_dynamic(
+                model, emulate=True, use_float8_linear=use_float8_linear
+            )
         model_unsharded_numel = sum(p.numel() for p in model.parameters())
         model_sharded_numel = (model_unsharded_numel + 1) // 2
         block_lin_weight_numel = 0
@@ -242,16 +281,23 @@ class TestFloat8MultiThread(FSDPTestMultiThread, TestFloat8Common):
     def world_size(self) -> int:
         return 2

-    @unittest.skipIf(not TEST_CUDA, "no cuda")
-    def test_weight_subclass_dynamic(self):
+    def _test_weight_subclass_dynamic(self, use_float8_linear):
+        float8_cls = Float8Linear if use_float8_linear else Float8DynamicLinear
+        extra_kwargs = {}
+        if use_float8_linear:
+            extra_kwargs["scaling_type_x"] = TensorScalingType.DYNAMIC
+            extra_kwargs["scaling_type_w"] = TensorScalingType.DYNAMIC
+            extra_kwargs["scaling_type_dL_dY"] = TensorScalingType.DYNAMIC
+            pass
         tensor_cls = WeightWithDynamicFloat8CastTensor
         # Check for a single FSDP paramter group
         module_fp32 = self.init_single_module()
         with set_enable_fsdp_fp8_all_gather(True):
             module = swap_linear_with_float8_linear(
                 module_fp32,
-                Float8DynamicLinear,
+                float8_cls,
                 emulate=True,
+                **extra_kwargs,
             )
         self.assertIsInstance(module.weight, tensor_cls)
         fully_shard(module)
@@ -265,8 +311,9 @@ def test_weight_subclass_dynamic(self):
         with set_enable_fsdp_fp8_all_gather(True):
             module = swap_linear_with_float8_linear(
                 module,
-                Float8DynamicLinear,
+                float8_cls,
                 emulate=True,
+                **extra_kwargs,
             )
         for param_name, param in module.named_parameters():
             if "weight" in param_name:
@@ -280,7 +327,14 @@ def test_weight_subclass_dynamic(self):
                 self.assertIsInstance(param.to_local(), tensor_cls)

     @unittest.skipIf(not TEST_CUDA, "no cuda")
-    def test_fp8_fp32_all_gather_dynamic_comm_size(self):
+    def test_weight_subclass_float8_dynamic_linear(self):
+        self._test_weight_subclass_dynamic(use_float8_linear=False)
+
+    @unittest.skipIf(not TEST_CUDA, "no cuda")
+    def test_weight_subclass_float8_linear(self):
+        self._test_weight_subclass_dynamic(use_float8_linear=True)
+
+    def _test_fp8_fp32_all_gather_dynamic_comm_size(self, use_float8_linear):
         """
         Tests that fp8 all-gather with dynamic scaling communicates the
         expected number of bytes.
@@ -314,7 +368,7 @@ def get_expected_all_gather_size(module: nn.Module):
         module_fp32 = self.init_single_module()
         ref_module = copy.deepcopy(module_fp32)
         with set_enable_fsdp_fp8_all_gather(True):
-            module = self.swap_linear_with_dynamic(module_fp32)
+            module = self.swap_linear_with_dynamic(module_fp32, use_float8_linear)
         fully_shard(module)
         local_inp = self.get_local_inp()
         expected_all_gather_size = get_expected_all_gather_size(ref_module)
@@ -358,18 +412,30 @@ def get_expected_all_gather_size(module: nn.Module):
             [s for s in expected_all_gather_sizes for _ in range(self.world_size)],
         )

+    @unittest.skipIf(not TEST_CUDA, "no cuda")
+    def test_fp8_fp32_all_gather_float8_dynamic_linear_comm_size(self):
+        self._test_fp8_fp32_all_gather_dynamic_comm_size(use_float8_linear=False)
+
+    @unittest.skipIf(not TEST_CUDA, "no cuda")
+    def test_fp8_fp32_all_gather_float8_linear_comm_size(self):
+        self._test_fp8_fp32_all_gather_dynamic_comm_size(use_float8_linear=True)
+
     @unittest.skipIf(not TEST_CUDA, "no cuda")
     def test_fp32_fp8_single_module_parity(self):
         """
         Tests numeric parity for fp32 parameters with fp8 computation with a
         single module/FSDP communication group.
         """
-        for enable_fsdp_fp8_all_gather in [False, True]:
+        for enable_fsdp_fp8_all_gather, use_float8_linear in itertools.product(
+            [False, True], [False, True]
+        ):
             module_fp32 = self.init_single_module()
-            ref_module = self.swap_linear_with_dynamic(copy.deepcopy(module_fp32))
+            ref_module = self.swap_linear_with_dynamic(
+                copy.deepcopy(module_fp32), use_float8_linear
+            )
             ref_module = ref_module.cuda()
             with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
-                module = self.swap_linear_with_dynamic(module_fp32)
+                module = self.swap_linear_with_dynamic(module_fp32, use_float8_linear)
             fully_shard(module)
             ref_optim = torch.optim.Adam(ref_module.parameters(), lr=1e-2)
             optim = torch.optim.Adam(module.parameters(), lr=1e-2, foreach=True)
@@ -390,12 +456,16 @@ def test_fp32_fp8_multi_module_parity(self):
         Tests numeric parity for fp32 parameters with fp8 computation with
         multiple modules/FSDP communication groups.
         """
-        for enable_fsdp_fp8_all_gather in [False, True]:
+        for enable_fsdp_fp8_all_gather, use_float8_linear in itertools.product(
+            [False, True], [False, True]
+        ):
             module = self.init_multi_module()
             ref_module = copy.deepcopy(module)
-            ref_module = self.swap_linear_with_dynamic(ref_module).cuda()
+            ref_module = self.swap_linear_with_dynamic(
+                ref_module, use_float8_linear
+            ).cuda()
             with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
-                module = self.swap_linear_with_dynamic(module)
+                module = self.swap_linear_with_dynamic(module, use_float8_linear)
             for submodule in module:
                 fully_shard(submodule)
             fully_shard(module)