@@ -1,4 +1,5 @@
 import copy
+import itertools
 import threading
 import unittest
 from typing import Any, List
@@ -11,6 +12,7 @@
     Float8DynamicLinear,
     WeightWithDynamicFloat8CastTensor,
 )
+from float8_experimental.float8_linear import Float8Linear, TensorScalingType
 from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
 from test_fsdp2_common import (
     check_parity_bf16_mp,
@@ -74,8 +76,16 @@ def get_local_inp(self, dtype: torch.dtype = torch.float32):
         dist.broadcast(global_inp, src=0)
         return global_inp.view(self.world_size, -1)[self.rank].view(16, 16)
 
-    def swap_linear_with_dynamic(self, module: nn.Module, **kwargs: Any) -> nn.Module:
-        return swap_linear_with_float8_linear(module, Float8DynamicLinear, **kwargs)
+    def swap_linear_with_dynamic(
+        self, module: nn.Module, use_float8_linear=False, **kwargs: Any
+    ) -> nn.Module:
+        if use_float8_linear:
+            kwargs["scaling_type_x"] = TensorScalingType.DYNAMIC
+            kwargs["scaling_type_w"] = TensorScalingType.DYNAMIC
+            kwargs["scaling_type_dL_dY"] = TensorScalingType.DYNAMIC
+            return swap_linear_with_float8_linear(module, Float8Linear, **kwargs)
+        else:
+            return swap_linear_with_float8_linear(module, Float8DynamicLinear, **kwargs)
 
 
 class TestFloat8MultiProcess(FSDPTest, TestFloat8Common):
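The updated helper above is a thin wrapper over two call shapes that both appear in this file. As a minimal standalone sketch of the two paths (the float8_dynamic_linear import path is an assumption; only the float8_linear and float8_linear_utils paths are visible in this diff):

import copy

import torch.nn as nn

from float8_experimental.float8_dynamic_linear import Float8DynamicLinear  # assumed module path
from float8_experimental.float8_linear import Float8Linear, TensorScalingType
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear

toy = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16))

# use_float8_linear=False: keep the existing Float8DynamicLinear swap.
dynamic_swapped = swap_linear_with_float8_linear(copy.deepcopy(toy), Float8DynamicLinear)

# use_float8_linear=True: swap to Float8Linear with every scaling type pinned to
# DYNAMIC, matching the behavior the tests compare against.
linear_swapped = swap_linear_with_float8_linear(
    copy.deepcopy(toy),
    Float8Linear,
    scaling_type_x=TensorScalingType.DYNAMIC,
    scaling_type_w=TensorScalingType.DYNAMIC,
    scaling_type_dL_dY=TensorScalingType.DYNAMIC,
)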
@@ -85,20 +95,26 @@ def world_size(self) -> int:
 
     @skip_if_lt_x_gpu(2)
     def test_transformer_parity_dynamic(self):
-        for enable_fsdp_fp8_all_gather in [False, True]:
-            self._test_transformer_parity_dynamic(enable_fsdp_fp8_all_gather)
+        for enable_fsdp_fp8_all_gather, use_float8_linear in itertools.product(
+            [False, True], [False, True]
+        ):
+            self._test_transformer_parity_dynamic(
+                enable_fsdp_fp8_all_gather, use_float8_linear
+            )
 
-    def _test_transformer_parity_dynamic(self, enable_fsdp_fp8_all_gather: bool):
+    def _test_transformer_parity_dynamic(
+        self, enable_fsdp_fp8_all_gather: bool, use_float8_linear: bool
+    ):
         # NOTE: Weight-tying does not compose with fp8 all-gather because the
         # embedding weight and output linear weight are tied but only the
         # latter uses fp8 compute. With fp8 all-gather, FSDP would pre-cast to
         # fp8 for that tied weight, incorrectly using fp8 for the embedding.
         weight_tying = not enable_fsdp_fp8_all_gather
         module = self.init_transformer(weight_tying=weight_tying)
         ref_module = copy.deepcopy(module)
-        ref_module = self.swap_linear_with_dynamic(ref_module).cuda()
+        ref_module = self.swap_linear_with_dynamic(ref_module, use_float8_linear).cuda()
         with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
-            module = self.swap_linear_with_dynamic(module)
+            module = self.swap_linear_with_dynamic(module, use_float8_linear)
         for submodule in module.modules():
             if isinstance(submodule, TransformerBlock):
                 fully_shard(submodule)
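Each loop that this PR switches to itertools.product expands to four flag combinations, so the parity and memory tests cover both all-gather modes for both linear classes. A quick standalone check of the expansion:

import itertools

combos = list(itertools.product([False, True], [False, True]))
print(combos)
# [(False, False), (False, True), (True, False), (True, True)]
# i.e. every (enable_fsdp_fp8_all_gather, use_float8_linear) pair runs once.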
@@ -108,17 +124,24 @@ def _test_transformer_parity_dynamic(self, enable_fsdp_fp8_all_gather: bool):
         local_inp = torch.randint(
             0, ref_module.tok_embeddings.weight.size(0), (16, 16), device="cuda"
         )
+        # TODO(future): change Float8DynamicLinear to module_cls below, and
+        # ensure there is no amax syncing for all-dynamic
         check_parity_no_mp(
             self, ref_module, ref_optim, module, optim, local_inp, Float8DynamicLinear
         )
 
     @skip_if_lt_x_gpu(2)
     def test_transformer_memory(self):
         """Tests peak active memory in the forward and backward passes."""
-        for enable_fsdp_fp8_all_gather in [False, True]:
-            self._test_transformer_memory(enable_fsdp_fp8_all_gather)
-
-    def _test_transformer_memory(self, enable_fsdp_fp8_all_gather: bool):
+        # for enable_fsdp_fp8_all_gather in [False, True]:
+        for enable_fsdp_fp8_all_gather, use_float8_linear in itertools.product(
+            [False, True], [False, True]
+        ):
+            self._test_transformer_memory(enable_fsdp_fp8_all_gather, use_float8_linear)
+
+    def _test_transformer_memory(
+        self, enable_fsdp_fp8_all_gather: bool, use_float8_linear: bool
+    ):
         torch.manual_seed(42)
         # Pre-run a linear forward (gemm and bias) and backward (gemm) to
         # allocate the cuBLAS workspaces before measuring the memory usage
@@ -141,7 +164,9 @@ def _test_transformer_memory(self, enable_fsdp_fp8_all_gather: bool):
         # Emulate the fp8 matmul to bypass the scaled matmul op's divisibility
         # requirement to use a smaller activation size
         with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
-            model = self.swap_linear_with_dynamic(model, emulate=True)
+            model = self.swap_linear_with_dynamic(
+                model, emulate=True, use_float8_linear=use_float8_linear
+            )
         model_unsharded_numel = sum(p.numel() for p in model.parameters())
         model_sharded_numel = (model_unsharded_numel + 1) // 2
         block_lin_weight_numel = 0
@@ -242,16 +267,23 @@ class TestFloat8MultiThread(FSDPTestMultiThread, TestFloat8Common):
     def world_size(self) -> int:
         return 2
 
-    @unittest.skipIf(not TEST_CUDA, "no cuda")
-    def test_weight_subclass_dynamic(self):
+    def _test_weight_subclass_dynamic(self, use_float8_linear):
+        float8_cls = Float8Linear if use_float8_linear else Float8DynamicLinear
+        extra_kwargs = {}
+        if use_float8_linear:
+            extra_kwargs["scaling_type_x"] = TensorScalingType.DYNAMIC
+            extra_kwargs["scaling_type_w"] = TensorScalingType.DYNAMIC
+            extra_kwargs["scaling_type_dL_dY"] = TensorScalingType.DYNAMIC
+            pass
         tensor_cls = WeightWithDynamicFloat8CastTensor
         # Check for a single FSDP parameter group
         module_fp32 = self.init_single_module()
         with set_enable_fsdp_fp8_all_gather(True):
             module = swap_linear_with_float8_linear(
                 module_fp32,
-                Float8DynamicLinear,
+                float8_cls,
                 emulate=True,
+                **extra_kwargs,
             )
         self.assertIsInstance(module.weight, tensor_cls)
         fully_shard(module)
@@ -265,8 +297,9 @@ def test_weight_subclass_dynamic(self):
         with set_enable_fsdp_fp8_all_gather(True):
             module = swap_linear_with_float8_linear(
                 module,
-                Float8DynamicLinear,
+                float8_cls,
                 emulate=True,
+                **extra_kwargs,
             )
         for param_name, param in module.named_parameters():
             if "weight" in param_name:
@@ -280,7 +313,14 @@ def test_weight_subclass_dynamic(self):
                 self.assertIsInstance(param.to_local(), tensor_cls)
 
     @unittest.skipIf(not TEST_CUDA, "no cuda")
-    def test_fp8_fp32_all_gather_dynamic_comm_size(self):
+    def test_weight_subclass_float8_dynamic_linear(self):
+        self._test_weight_subclass_dynamic(use_float8_linear=False)
+
+    @unittest.skipIf(not TEST_CUDA, "no cuda")
+    def test_weight_subclass_float8_linear(self):
+        self._test_weight_subclass_dynamic(use_float8_linear=True)
+
+    def _test_fp8_fp32_all_gather_dynamic_comm_size(self, use_float8_linear):
         """
         Tests that fp8 all-gather with dynamic scaling communicates the
         expected number of bytes.
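The "expected number of bytes" in the docstring above comes down to element width: a float8 element is all-gathered as one byte versus four bytes for float32. A back-of-the-envelope sketch, with the 16x16 weight shape assumed purely for illustration:

def approx_all_gather_bytes(weight_numel: int, fp8_all_gather: bool) -> int:
    # float8 elements occupy 1 byte on the wire; float32 elements occupy 4
    return weight_numel * (1 if fp8_all_gather else 4)

print(approx_all_gather_bytes(16 * 16, fp8_all_gather=True))   # 256
print(approx_all_gather_bytes(16 * 16, fp8_all_gather=False))  # 1024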
@@ -314,7 +354,7 @@ def get_expected_all_gather_size(module: nn.Module):
         module_fp32 = self.init_single_module()
         ref_module = copy.deepcopy(module_fp32)
         with set_enable_fsdp_fp8_all_gather(True):
-            module = self.swap_linear_with_dynamic(module_fp32)
+            module = self.swap_linear_with_dynamic(module_fp32, use_float8_linear)
         fully_shard(module)
         local_inp = self.get_local_inp()
         expected_all_gather_size = get_expected_all_gather_size(ref_module)
@@ -358,18 +398,30 @@ def get_expected_all_gather_size(module: nn.Module):
             [s for s in expected_all_gather_sizes for _ in range(self.world_size)],
         )
 
+    @unittest.skipIf(not TEST_CUDA, "no cuda")
+    def test_fp8_fp32_all_gather_float8_dynamic_linear_comm_size(self):
+        self._test_fp8_fp32_all_gather_dynamic_comm_size(use_float8_linear=False)
+
+    @unittest.skipIf(not TEST_CUDA, "no cuda")
+    def test_fp8_fp32_all_gather_float8_linear_comm_size(self):
+        self._test_fp8_fp32_all_gather_dynamic_comm_size(use_float8_linear=True)
+
     @unittest.skipIf(not TEST_CUDA, "no cuda")
     def test_fp32_fp8_single_module_parity(self):
         """
         Tests numeric parity for fp32 parameters with fp8 computation with a
         single module/FSDP communication group.
         """
-        for enable_fsdp_fp8_all_gather in [False, True]:
+        for enable_fsdp_fp8_all_gather, use_float8_linear in itertools.product(
+            [False, True], [False, True]
+        ):
             module_fp32 = self.init_single_module()
-            ref_module = self.swap_linear_with_dynamic(copy.deepcopy(module_fp32))
+            ref_module = self.swap_linear_with_dynamic(
+                copy.deepcopy(module_fp32), use_float8_linear
+            )
             ref_module = ref_module.cuda()
             with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
-                module = self.swap_linear_with_dynamic(module_fp32)
+                module = self.swap_linear_with_dynamic(module_fp32, use_float8_linear)
             fully_shard(module)
             ref_optim = torch.optim.Adam(ref_module.parameters(), lr=1e-2)
             optim = torch.optim.Adam(module.parameters(), lr=1e-2, foreach=True)
@@ -390,12 +442,16 @@ def test_fp32_fp8_multi_module_parity(self):
         Tests numeric parity for fp32 parameters with fp8 computation with
         multiple modules/FSDP communication groups.
         """
-        for enable_fsdp_fp8_all_gather in [False, True]:
+        for enable_fsdp_fp8_all_gather, use_float8_linear in itertools.product(
+            [False, True], [False, True]
+        ):
             module = self.init_multi_module()
             ref_module = copy.deepcopy(module)
-            ref_module = self.swap_linear_with_dynamic(ref_module).cuda()
+            ref_module = self.swap_linear_with_dynamic(
+                ref_module, use_float8_linear
+            ).cuda()
             with set_enable_fsdp_fp8_all_gather(enable_fsdp_fp8_all_gather):
-                module = self.swap_linear_with_dynamic(module)
+                module = self.swap_linear_with_dynamic(module, use_float8_linear)
             for submodule in module:
                 fully_shard(submodule)
             fully_shard(module)