# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
"""
- Test numerics of single GPU vs FSDP of toy model. At a high level:
- 1. start with reference input and state dict for a single GPU model
- 2. run fw+bw+optim on single GPU, save the results
- 3. run fw+bw+optim with FSDP, save the results
- 4. verify that the outputs and state dicts after optim update match
-
- later 1-4 can be repeated for fp16, various combinations of fp8, etc.
+ Test numerics of bf16 versus float8 with FSDP on. At a high level:
+ 1. start with a reference model, with FSDP on
+ 2. run forward + backward + optim for 2 iterations
+ 3. repeat 2 with float8 enabled (2 iterations needed for delayed scaling)
+ 4. compare outputs and state dict between (2) and (3), should be close
"""
+ import copy
import os
import warnings

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
- from float8_experimental.float8_linear import Float8Linear
+ from float8_experimental.float8_linear import Float8Linear, TensorScalingType
from float8_experimental.float8_linear_utils import (
    swap_linear_with_float8_linear,
    sync_float8_amax_and_scale_history,
)
+ from float8_experimental.float8_utils import compute_error
from torch.distributed.fsdp import (
    FullStateDictConfig,
    FullyShardedDataParallel as FSDP,
    StateDictType,
)

torch.manual_seed(0)

- # assumes user is running the script from /data/users/{user}/float8_experimental
- data_dir = os.path.join(os.path.dirname(__file__), "tmp")
- input_fname = os.path.join(data_dir, "input.pt")
- sd_in_fname = os.path.join(data_dir, "sd_in.pt")
- sd_out_single_gpu_fname = os.path.join(data_dir, "sd_out_single_gpu.pt")
- sd_out_fsdp_fname = os.path.join(data_dir, "sd_out_fsdp.pt")
- output_single_gpu_fname = os.path.join(data_dir, "output_single_gpu.pt")
- output_fsdp_fname = os.path.join(data_dir, "output_fsdp.pt")
-

B, M, K, N = 8, 8, 32, 32
lr = 0.01
- N_ITER = 5
+ N_ITER = 2


def setup(rank, world_size):
@@ -61,15 +52,13 @@ def cleanup():
    dist.destroy_process_group()


- def get_model(K, N, is_fp8, emulate, base_dtype=torch.float32):
+ def get_model(K, N, base_dtype=torch.float32):
    m = nn.Sequential(
        nn.Linear(K, N, dtype=base_dtype),
        nn.ReLU(),
        nn.Linear(N, N, dtype=base_dtype),
        nn.ReLU(),
    )
-     if is_fp8:
-         swap_linear_with_float8_linear(m, Float8Linear, emulate=emulate)
    return m

@@ -79,52 +68,84 @@ def fsdp_main(rank, world_size, args):
    setup(rank, world_size)
    torch.cuda.set_device(rank)

-     # TODO: We set fullgraph as an option. However, it currently doesn't work for fullgraph compile.
-     # We can investigate and fix it later.
-     is_fp8, emulate, base_dtype, compile, fullgraph = args
-     model = get_model(K, N, is_fp8=is_fp8, emulate=emulate, base_dtype=base_dtype).to(
-         rank
+     emulate, base_dtype, compile, use_weight_dynamic_scaling = args
+     model = get_model(K, N, base_dtype=base_dtype).to(rank)
+     model_fp8 = copy.deepcopy(model)
+
+     scaling_type_w = (
+         TensorScalingType.DYNAMIC
+         if use_weight_dynamic_scaling
+         else TensorScalingType.DELAYED
+     )
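+     # TensorScalingType.DYNAMIC derives the weight scale from the current
+     # tensor on every iteration; TensorScalingType.DELAYED reuses amax history
+     # buffers that are updated by sync_float8_amax_and_scale_history below.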
+
+     # Note: we only iterate over `scaling_type_w` because FSDP only interacts
+     # with weights.
+     swap_linear_with_float8_linear(
+         model_fp8,
+         Float8Linear,
+         emulate=False,
+         scaling_type_w=scaling_type_w,
    )
-     model.load_state_dict(torch.load(sd_in_fname, weights_only=True))
+
    # To compile FSDP, we need use_orig_params set to True
    model = FSDP(model, use_orig_params=True)
+     model_fp8 = FSDP(model_fp8, use_orig_params=True)
    # TODO: The following line doesn't work. We should fix it.
    # model = FSDP(torch.compile(model), use_orig_params=True)

-     # Note: we need to multiply by world_size here to match single GPU
-     # optimizer update
-     optimizer = torch.optim.SGD(model.parameters(), lr=lr * world_size)
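+     # both baselines now run under FSDP, so there is no single-GPU optimizer
+     # update to match and the learning rate is no longer scaled by world_size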
+     optimizer = torch.optim.SGD(model.parameters(), lr=lr)
+     optimizer_fp8 = torch.optim.SGD(model_fp8.parameters(), lr=lr)

-     ref_input_global = torch.load(input_fname, weights_only=True).to(base_dtype)
+     # Note: we need two different inputs to properly measure the impact of
+     # delayed scaling, because the first input uses dynamic scaling to
+     # populate the buffers
+     ref_input_global = [
+         torch.randn(B, M, K).cuda().to(base_dtype),
+         torch.randn(B, M, K).cuda().to(base_dtype),
+     ]
+     ref_grad_global = [
+         torch.randn(B, M, N).cuda().to(base_dtype),
+         torch.randn(B, M, N).cuda().to(base_dtype),
+     ]
+     ref_input_local = []
+     ref_grad_local = []

    # basic distributed data sampling
    assert B % world_size == 0
    bsz_local_start = int(rank / world_size * B)
    bsz_local_end = int((rank + 1) / world_size * B)
-     ref_input_local = ref_input_global[bsz_local_start:bsz_local_end].to(rank)
+     for idx in range(N_ITER):
+         ref_input_local.append(
+             ref_input_global[idx][bsz_local_start:bsz_local_end].to(rank)
+         )
+         ref_grad_local.append(
+             ref_grad_global[idx][bsz_local_start:bsz_local_end].to(rank)
+         )

    sync_float8_func = sync_float8_amax_and_scale_history
    if compile:
-         sync_float8_func = torch.compile(
-             sync_float8_amax_and_scale_history, fullgraph=fullgraph
-         )
-
-     def forward_backward(model):
-         optimizer.zero_grad()
-         y_local = model(ref_input_local)
-         y_local.sum().backward()
-         sync_float8_func(model)
-         optimizer.step()
+         sync_float8_func = torch.compile(sync_float8_amax_and_scale_history)
+
+     def forward_backward(model, optim, is_fp8, i):
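+         # run one forward/backward/optimizer step for iteration `i`; for the
+         # float8 model, also sync the amax and scale history before the
+         # optimizer step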
+         optim.zero_grad()
+         y_local = model(ref_input_local[i])
+         y_local.backward(ref_grad_local[i])
+         if is_fp8:
+             sync_float8_func(model)
+         optim.step()
        return y_local

-     for iter in range(N_ITER):
+     for i in range(N_ITER):
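+         # on iteration 0 the float8 model populates its amax/scale buffers
+         # (the "is_amax_initialized == False" branch below); iteration 1 then
+         # exercises delayed scaling with that history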
        # We first run one iteration without compile, as a workaround to compile the float8 layers.
        # In the first iter, float8 layers go to the branches of "self.is_amax_initialized == False"
        # After that, float8 layers go to the branches of "self.is_amax_initialized == True"
        # TODO: Need to fix compile to run without this workaround.
-         if iter == 1 and compile:
-             model = torch.compile(model, fullgraph=fullgraph)
-         y_local = forward_backward(model)
+         if i == 1 and compile:
+             model = torch.compile(model)
+             model_fp8 = torch.compile(model_fp8)
+         y_local = forward_backward(model, optimizer, is_fp8=False, i=i)
+         y_local_fp8 = forward_backward(model_fp8, optimizer_fp8, is_fp8=True, i=i)
+         local_sqnr = compute_error(y_local, y_local_fp8)

    # get global y
    y_global = [
@@ -133,132 +154,50 @@ def forward_backward(model):
    ]
    dist.all_gather(y_global, y_local)
    y_global = torch.cat(y_global, dim=0)
+     y_global_fp8 = [
+         torch.zeros(*y_local_fp8.shape, dtype=base_dtype).to(rank)
+         for r in range(world_size)
+     ]
+     dist.all_gather(y_global_fp8, y_local_fp8)
+     y_global_fp8 = torch.cat(y_global_fp8, dim=0)
    if rank == 0:
-         torch.save(y_global, output_fsdp_fname)
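+         # compute_error returns the SQNR between the two outputs in dB;
+         # > 15 dB is treated here as "numerically close"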
+         sqnr = compute_error(y_global, y_global_fp8)
+         assert sqnr > 15.0, f"SQNR of {sqnr} is too low"

    # get global state dict
    # https://pytorch.org/tutorials/intermediate/FSDP_adavnced_tutorial.html
    dist.barrier()
    save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
        cpu_state = model.state_dict()
+     with FSDP.state_dict_type(model_fp8, StateDictType.FULL_STATE_DICT, save_policy):
+         cpu_state_fp8 = model_fp8.state_dict()
    if rank == 0:
-         torch.save(cpu_state, sd_out_fsdp_fname)
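+         # the float8 model's state dict contains the same weight/bias keys as
+         # the bf16 model (plus extra fp8 buffers), so we iterate over the bf16
+         # keys and compare the corresponding tensors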
+         for k, v1 in cpu_state.items():
+             v2 = cpu_state_fp8[k]
+             v1, v2 = v1.cpu(), v2.cpu()
+             sqnr = compute_error(v1, v2)
+             assert sqnr > 15.0, f"SQNR of {sqnr} is too low, k: {k}, v1: {v1}, v2: {v2}"

    cleanup()

- def run(mode: str, is_fp8: bool, compile_fsdp: bool = False, fullgraph: bool = False):
-     print(f"Mode: {mode}".center(100, "-"))
+ def run(compile_fsdp: bool = False, use_weight_dynamic_scaling: bool = False):
    base_dtype = torch.bfloat16
-     if not os.path.exists(data_dir):
-         os.makedirs(data_dir)

    emulate = False
-     if is_fp8:
-         if not torch.cuda.is_available():
-             warnings.warn("CUDA not available, running in emulation_mode")
-             emulate = True
-         elif torch.cuda.get_device_capability() < (9, 0):
-             warnings.warn(
-                 f"CUDA capability {torch.cuda.get_device_capability()} < (9.0), running in emulation mode"
-             )
-             emulate = True
-
-     if mode == "generate":
-         # generate reference input
-         ref_input = torch.randn(B, M, K).cuda().to(base_dtype)
-         model = get_model(
-             K, N, is_fp8=is_fp8, emulate=emulate, base_dtype=base_dtype
-         ).cuda()
-         torch.save(ref_input, input_fname)
-         torch.save(model.state_dict(), sd_in_fname)
-
-     elif mode == "single_gpu":
-         ref_input = torch.load(input_fname, weights_only=True).to(base_dtype)
-         model = get_model(
-             K, N, is_fp8=is_fp8, emulate=emulate, base_dtype=base_dtype
-         ).cuda()
-         model.load_state_dict(torch.load(sd_in_fname, weights_only=True))
-         optimizer = torch.optim.SGD(model.parameters(), lr=lr)
-
-         def forward_backward():
-             optimizer.zero_grad()
-             y = model(ref_input)
-             y.sum().backward()
-             sync_float8_amax_and_scale_history(model)
-             optimizer.step()
-             return y
-
-         for _ in range(N_ITER):
-             y = forward_backward()
-
-         torch.save(y, output_single_gpu_fname)
-         torch.save(model.state_dict(), sd_out_single_gpu_fname)
-
-     elif mode == "fsdp":
-         WORLD_SIZE = torch.cuda.device_count()
-         # We only compile for fsdp, and compare the numerics with single-gpu no-compile
-         args = (is_fp8, emulate, base_dtype, compile_fsdp, fullgraph)
-         mp.spawn(fsdp_main, args=(WORLD_SIZE, args), nprocs=WORLD_SIZE, join=True)
-
-     elif mode == "analyze":
-         y_single_gpu = torch.load(output_single_gpu_fname, weights_only=True).cpu()
-         y_fsdp = torch.load(output_fsdp_fname, weights_only=True).cpu()
-         if is_fp8 and not emulate:
-             atol, rtol = 2e-2, 2e-2
-         else:
-             atol, rtol = None, None
-         torch.testing.assert_close(y_single_gpu, y_fsdp, atol=atol, rtol=rtol)
-         print("output testing single_gpu vs FSDP success")
-
-         sd_out_single_gpu = torch.load(sd_out_single_gpu_fname, weights_only=True)
-         sd_out_fsdp = torch.load(sd_out_fsdp_fname, weights_only=True)
-         for k, v1 in sd_out_single_gpu.items():
-             if compile_fsdp:
-                 # The state-dict for compiled fsdp has a `_orig_mod` prefix
-                 k = f"_orig_mod.{k}"
-             v2 = sd_out_fsdp[k]
-             v1, v2 = v1.cpu(), v2.cpu()
-             if is_fp8 and "noop" in k:
-                 # Note: for fp8 single-node vs FSDP, we are not expected
-                 # to match the scale of the gradients which follow the pattern:
-                 #
-                 # `op(g_prev, out_scale) -> g_fp8 -> cast -> g_fp16 -> reduce`.
-                 #
-                 # Reasoning is the order of operations of calculating the above:
-                 # a. single node:
-                 #   1. calculate dL_dValue and s_dL_dValue
-                 #   2. you're done
-                 # b. FSDP:
-                 #   1. calculate dL_dValue and s_dL_dValue of each slice
-                 #   2. reduce using summation
-                 #
-                 # a and b cannot always match because calculating the scale
-                 # involves taking max(dL_dW), FSDP reduces the gradients, and
-                 # max(abs(a), abs(b)) != max(abs(a + b))
-                 #
-                 # In today's codebase, we do not hit this yet. We expect to hit
-                 # this if we implement TP with activation gradients that both need
-                 # reductions and need fp8 distributed comms. Solution - TBD.
-
-                 # noop buffers are unused, so ok for them to not match
-                 pass
-             else:
-                 try:
-                     if v1.dtype == torch.bfloat16 and not emulate:
-                         atol, rtol = 2e-2, 2e-2
-                     else:
-                         if k == "1.fp8_amax_history_x" and not emulate:
-                             atol, rtol = 2e-2, 6e-3
-                         else:
-                             atol, rtol = None, None
-                     torch.testing.assert_close(v1, v2, atol=atol, rtol=rtol)
-                 except Exception as e:
-                     print("debug:", k, v1, v2)
-                     raise e
-         print("state dict testing single_gpu vs FSDP success")
+     if not torch.cuda.is_available():
+         warnings.warn("CUDA not available, running in emulation_mode")
+         emulate = True
+     elif torch.cuda.get_device_capability() < (9, 0):
+         warnings.warn(
+             f"CUDA capability {torch.cuda.get_device_capability()} < (9.0), running in emulation mode"
+         )
+         emulate = True
+
+     WORLD_SIZE = torch.cuda.device_count()
+     args = (emulate, base_dtype, compile_fsdp, use_weight_dynamic_scaling)
+     mp.spawn(fsdp_main, args=(WORLD_SIZE, args), nprocs=WORLD_SIZE, join=True)


if __name__ == "__main__":