@@ -185,102 +185,110 @@ def sync_float8_amax_and_scale_history(model: torch.nn.Module, fp8_layers=None)
         )
         return
 
-    # Loop over all fp8 layers and grab the needed tensors
-    fp8_amax_x_tensor_list = [None] * len(fp8_layers)
-    fp8_amax_w_tensor_list = [None] * len(fp8_layers)
-    fp8_amax_dL_dY_tensor_list = [None] * len(fp8_layers)
-
-    fp8_x_amax_history_stack = [None] * len(fp8_layers)
-    fp8_w_amax_history_stack = [None] * len(fp8_layers)
-    fp8_dL_dY_amax_history_stack = [None] * len(fp8_layers)
-
-    x_dtypes = set()
-    scale_fn_recipes = set()
-
-    for idx, child in enumerate(fp8_layers):
-        fp8_amax_x_tensor_list[idx] = child.fp8_amax_x
-        fp8_amax_w_tensor_list[idx] = child.fp8_amax_w
-        fp8_amax_dL_dY_tensor_list[idx] = child.fp8_amax_dL_dY
-
-        fp8_x_amax_history_stack[idx] = child.fp8_amax_history_x
-        fp8_w_amax_history_stack[idx] = child.fp8_amax_history_w
-        fp8_dL_dY_amax_history_stack[idx] = child.fp8_amax_history_dL_dY
-
-        x_dtypes.add(child.last_seen_input_dtype)
-        scale_fn_recipes.add(child.recipe.scale_fn_name)
-
-    # TODO This way to get the activation dtype is not ideal
-    if len(x_dtypes) != 1:
-        raise ValueError(
-            f"All layers must have the same last seen input_dtype, got {x_dtypes}"
-        )
-    x_dtype = next(iter(x_dtypes))
+    def inner_func():
+        # Loop over all fp8 layers and grab the needed tensors
+        fp8_amax_x_tensor_list = [None] * len(fp8_layers)
+        fp8_amax_w_tensor_list = [None] * len(fp8_layers)
+        fp8_amax_dL_dY_tensor_list = [None] * len(fp8_layers)
 
-    if len(scale_fn_recipes) != 1:
-        raise ValueError(
-            f"All layers must have the same scale_fn recipe, got {scale_fn_recipes}"
-        )
-    scale_fn_recipe = next(iter(scale_fn_recipes))
+        fp8_x_amax_history_stack = [None] * len(fp8_layers)
+        fp8_w_amax_history_stack = [None] * len(fp8_layers)
+        fp8_dL_dY_amax_history_stack = [None] * len(fp8_layers)
 
-    assert (
-        len(fp8_amax_x_tensor_list)
-        == len(fp8_amax_w_tensor_list)
-        == len(fp8_amax_dL_dY_tensor_list)
-    ), "Mismatched lengths of amax tensors."
-
-    if dist.is_initialized():
-        # Combine all the amax tensors into one tensor and reduce it
-        all_amax_tensors = torch.cat(
-            fp8_amax_x_tensor_list + fp8_amax_w_tensor_list + fp8_amax_dL_dY_tensor_list
+        x_dtypes = set()
+        scale_fn_recipes = set()
+
+        for idx, child in enumerate(fp8_layers):
+            fp8_amax_x_tensor_list[idx] = child.fp8_amax_x
+            fp8_amax_w_tensor_list[idx] = child.fp8_amax_w
+            fp8_amax_dL_dY_tensor_list[idx] = child.fp8_amax_dL_dY
+
+            fp8_x_amax_history_stack[idx] = child.fp8_amax_history_x
+            fp8_w_amax_history_stack[idx] = child.fp8_amax_history_w
+            fp8_dL_dY_amax_history_stack[idx] = child.fp8_amax_history_dL_dY
+
+            x_dtypes.add(child.last_seen_input_dtype)
+            scale_fn_recipes.add(child.recipe.scale_fn_name)
+
+        # TODO This way to get the activation dtype is not ideal
+        if len(x_dtypes) != 1:
+            raise ValueError(
+                f"All layers must have the same last seen input_dtype, got {x_dtypes}"
+            )
+        x_dtype = next(iter(x_dtypes))
+
+        if len(scale_fn_recipes) != 1:
+            raise ValueError(
+                f"All layers must have the same scale_fn recipe, got {scale_fn_recipes}"
+            )
+        scale_fn_recipe = next(iter(scale_fn_recipes))
+
+        assert (
+            len(fp8_amax_x_tensor_list)
+            == len(fp8_amax_w_tensor_list)
+            == len(fp8_amax_dL_dY_tensor_list)
+        ), "Mismatched lengths of amax tensors."
+
+        if dist.is_initialized():
+            # Combine all the amax tensors into one tensor and reduce it
+            all_amax_tensors = torch.cat(
+                fp8_amax_x_tensor_list
+                + fp8_amax_w_tensor_list
+                + fp8_amax_dL_dY_tensor_list
+            )
+            all_reduced_amax_tensor = all_reduce(
+                all_amax_tensors, "MAX", list(range(dist.get_world_size()))
+            )
+            if isinstance(all_reduced_amax_tensor, AsyncCollectiveTensor):
+                all_reduced_amax_tensor = all_reduced_amax_tensor.wait()
+
+            (
+                reduced_fp8_amax_tensor,
+                reduced_fp8_amax_w_tensor,
+                reduced_fp8_amax_dL_dY_tensor,
+            ) = torch.split(all_reduced_amax_tensor, len(fp8_amax_x_tensor_list))
+
+            for idx, child in enumerate(fp8_layers):
+                child.fp8_amax_x.copy_(reduced_fp8_amax_tensor[idx])
+                child.fp8_amax_w.copy_(reduced_fp8_amax_w_tensor[idx])
+                child.fp8_amax_dL_dY.copy_(reduced_fp8_amax_dL_dY_tensor[idx])
+
+        # We create two stacked tensor groups, one for the amax history and one for the current scales
+        fp8_amax_x_tensors = torch.vstack(fp8_amax_x_tensor_list)
+        fp8_amax_w_tensors = torch.vstack(fp8_amax_w_tensor_list)
+        fp8_amax_dL_dY_tensors = torch.vstack(fp8_amax_dL_dY_tensor_list)
+
+        fp8_x_amax_history_stack = torch.vstack(fp8_x_amax_history_stack)
+        fp8_w_amax_history_stack = torch.vstack(fp8_w_amax_history_stack)
+        fp8_dL_dY_amax_history_stack = torch.vstack(fp8_dL_dY_amax_history_stack)
+
+        # Update the history stacks with the new amax values
+        _update_history_stack(fp8_amax_x_tensors, fp8_x_amax_history_stack)
+        _update_history_stack(fp8_amax_w_tensors, fp8_w_amax_history_stack)
+        _update_history_stack(fp8_amax_dL_dY_tensors, fp8_dL_dY_amax_history_stack)
+
+        # Calculate the new scales from the updated history stacks
+        new_x_scales = amax_history_to_scale_stack(
+            fp8_x_amax_history_stack, torch.float8_e4m3fn, x_dtype, scale_fn_recipe
         )
-        all_reduced_amax_tensor = all_reduce(
-            all_amax_tensors, "MAX", list(range(dist.get_world_size()))
+        new_w_scales = amax_history_to_scale_stack(
+            fp8_w_amax_history_stack, torch.float8_e4m3fn, x_dtype, scale_fn_recipe
+        )
+        new_dL_dY_scales = amax_history_to_scale_stack(
+            fp8_dL_dY_amax_history_stack, torch.float8_e5m2, x_dtype, scale_fn_recipe
         )
-        if isinstance(all_reduced_amax_tensor, AsyncCollectiveTensor):
-            all_reduced_amax_tensor = all_reduced_amax_tensor.wait()
-
-        (
-            reduced_fp8_amax_tensor,
-            reduced_fp8_amax_w_tensor,
-            reduced_fp8_amax_dL_dY_tensor,
-        ) = torch.split(all_reduced_amax_tensor, len(fp8_amax_x_tensor_list))
 
+        # Iterate through the layers and update the scales, and set the flag to signal that the amaxes/scales are ready
         for idx, child in enumerate(fp8_layers):
-            child.fp8_amax_x.copy_(reduced_fp8_amax_tensor[idx])
-            child.fp8_amax_w.copy_(reduced_fp8_amax_w_tensor[idx])
-            child.fp8_amax_dL_dY.copy_(reduced_fp8_amax_dL_dY_tensor[idx])
-
-    # We create two stacked tensor groups, one for the amax history and one for the current scales
-    fp8_amax_x_tensors = torch.vstack(fp8_amax_x_tensor_list)
-    fp8_amax_w_tensors = torch.vstack(fp8_amax_w_tensor_list)
-    fp8_amax_dL_dY_tensors = torch.vstack(fp8_amax_dL_dY_tensor_list)
-
-    fp8_x_amax_history_stack = torch.vstack(fp8_x_amax_history_stack)
-    fp8_w_amax_history_stack = torch.vstack(fp8_w_amax_history_stack)
-    fp8_dL_dY_amax_history_stack = torch.vstack(fp8_dL_dY_amax_history_stack)
-
-    # Update the history stacks with the new amax values
-    _update_history_stack(fp8_amax_x_tensors, fp8_x_amax_history_stack)
-    _update_history_stack(fp8_amax_w_tensors, fp8_w_amax_history_stack)
-    _update_history_stack(fp8_amax_dL_dY_tensors, fp8_dL_dY_amax_history_stack)
-
-    # Calculate the new scales from the updated history stacks
-    new_x_scales = amax_history_to_scale_stack(
-        fp8_x_amax_history_stack, torch.float8_e4m3fn, x_dtype, scale_fn_recipe
-    )
-    new_w_scales = amax_history_to_scale_stack(
-        fp8_w_amax_history_stack, torch.float8_e4m3fn, x_dtype, scale_fn_recipe
-    )
-    new_dL_dY_scales = amax_history_to_scale_stack(
-        fp8_dL_dY_amax_history_stack, torch.float8_e5m2, x_dtype, scale_fn_recipe
-    )
+            child.fp8_scale_x.copy_(new_x_scales[idx])
+            child.fp8_scale_w.copy_(new_w_scales[idx])
+            child.fp8_scale_dL_dY.copy_(new_dL_dY_scales[idx])
 
-    # Iterate through the layers and update the scales, and set the flag to signal that the amaxes/scales are ready
-    for idx, child in enumerate(fp8_layers):
-        child.fp8_scale_x.copy_(new_x_scales[idx])
-        child.fp8_scale_w.copy_(new_w_scales[idx])
-        child.fp8_scale_dL_dY.copy_(new_dL_dY_scales[idx])
+    # When cuda-graphs work we should compile with "reduce-overhead"
+    compiled_inner_func = torch.compile(inner_func)
+    compiled_inner_func()
 
+    for child in fp8_layers:
         # 4. set a flag to signal amaxes/scales are ready
         # We only update the flag if we know it will be checked by the modules
         if fp8_config.enable_amax_init:
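Nothing changes on the caller side: the compiled `inner_func` runs inside `sync_float8_amax_and_scale_history`, which is still invoked once per training iteration. Below is a minimal usage sketch; the import path, model setup, and training-loop details are assumptions for illustration and are not part of this hunk.

```python
import torch

# Assumed import path for illustration; this hunk only shows the function body.
from float8_experimental.float8_linear_utils import sync_float8_amax_and_scale_history


def train_step(model, optimizer, loss_fn, batch, target):
    optimizer.zero_grad()
    loss = loss_fn(model(batch), target)
    loss.backward()
    # Sync amax values / recompute scales for all fp8 layers before the
    # optimizer step; the heavy per-layer math is now torch.compile'd.
    sync_float8_amax_and_scale_history(model)
    optimizer.step()
```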