@@ -185,103 +185,122 @@ def sync_float8_amax_and_scale_history(model: torch.nn.Module, fp8_layers=None)
         )
         return
 
-    # Loop over all fp8 layers and grab the needed tensors
-    fp8_amax_x_tensor_list = [None] * len(fp8_layers)
-    fp8_amax_w_tensor_list = [None] * len(fp8_layers)
-    fp8_amax_dL_dY_tensor_list = [None] * len(fp8_layers)
-
-    fp8_x_amax_history_stack = [None] * len(fp8_layers)
-    fp8_w_amax_history_stack = [None] * len(fp8_layers)
-    fp8_dL_dY_amax_history_stack = [None] * len(fp8_layers)
-
-    x_dtypes = set()
-    scale_fn_recipes = set()
-
-    for idx, child in enumerate(fp8_layers):
-        fp8_amax_x_tensor_list[idx] = child.fp8_amax_x
-        fp8_amax_w_tensor_list[idx] = child.fp8_amax_w
-        fp8_amax_dL_dY_tensor_list[idx] = child.fp8_amax_dL_dY
-
-        fp8_x_amax_history_stack[idx] = child.fp8_amax_history_x
-        fp8_w_amax_history_stack[idx] = child.fp8_amax_history_w
-        fp8_dL_dY_amax_history_stack[idx] = child.fp8_amax_history_dL_dY
-
-        x_dtypes.add(child.last_seen_input_dtype)
-        scale_fn_recipes.add(child.recipe.scale_fn_name)
-
-    # TODO This way to get the activation dtype is not ideal
-    if len(x_dtypes) != 1:
-        raise ValueError(
-            f"All layers must have the same last seen input_dtype, got {x_dtypes}"
-        )
-    x_dtype = next(iter(x_dtypes))
+    def inner_func():
+        """Why do we have this inner_function?
+
+        There are two portions of the outer sync_function that cause graph_breaks:
+            1. The `get_float8_layers` call can cause graph breaks if the user did not pass
+               in the fp8_layers.
+            2. At the end of syncing all the amaxes and scales we set the attr on the module
+               signaling that we have synced the amaxes and scales and the next forward can be run.
+               # TODO Maybe we should remove this safety check to remove the graph break?
+
+        By having this inner function, we can ensure that although the outer function may cause graph breaks,
+        the inner function will not.
+        """
+        # Loop over all fp8 layers and grab the needed tensors
+        fp8_amax_x_tensor_list = [None] * len(fp8_layers)
+        fp8_amax_w_tensor_list = [None] * len(fp8_layers)
+        fp8_amax_dL_dY_tensor_list = [None] * len(fp8_layers)
+
+        fp8_x_amax_history_stack = [None] * len(fp8_layers)
+        fp8_w_amax_history_stack = [None] * len(fp8_layers)
+        fp8_dL_dY_amax_history_stack = [None] * len(fp8_layers)
+
+        x_dtypes = set()
+        scale_fn_recipes = set()
 
-    if len(scale_fn_recipes) != 1:
-        raise ValueError(
-            f"All layers must have the same scale_fn recipe, got {scale_fn_recipes}"
-        )
-    scale_fn_recipe = next(iter(scale_fn_recipes))
+        for idx, child in enumerate(fp8_layers):
+            fp8_amax_x_tensor_list[idx] = child.fp8_amax_x
+            fp8_amax_w_tensor_list[idx] = child.fp8_amax_w
+            fp8_amax_dL_dY_tensor_list[idx] = child.fp8_amax_dL_dY
 
-    assert (
-        len(fp8_amax_x_tensor_list)
-        == len(fp8_amax_w_tensor_list)
-        == len(fp8_amax_dL_dY_tensor_list)
-    ), "Mismatched lengths of amax tensors."
-
-    if dist.is_initialized():
-        # Combine all the amax tensors into one tensor and reduce it
-        all_amax_tensors = torch.cat(
-            fp8_amax_x_tensor_list + fp8_amax_w_tensor_list + fp8_amax_dL_dY_tensor_list
+            fp8_x_amax_history_stack[idx] = child.fp8_amax_history_x
+            fp8_w_amax_history_stack[idx] = child.fp8_amax_history_w
+            fp8_dL_dY_amax_history_stack[idx] = child.fp8_amax_history_dL_dY
+
+            x_dtypes.add(child.last_seen_input_dtype)
+            scale_fn_recipes.add(child.recipe.scale_fn_name)
+
+        # TODO This way to get the activation dtype is not ideal
+        if len(x_dtypes) != 1:
+            raise ValueError(
+                f"All layers must have the same last seen input_dtype, got {x_dtypes}"
+            )
+        x_dtype = next(iter(x_dtypes))
+
+        if len(scale_fn_recipes) != 1:
+            raise ValueError(
+                f"All layers must have the same scale_fn recipe, got {scale_fn_recipes}"
+            )
+        scale_fn_recipe = next(iter(scale_fn_recipes))
+
+        assert (
+            len(fp8_amax_x_tensor_list)
+            == len(fp8_amax_w_tensor_list)
+            == len(fp8_amax_dL_dY_tensor_list)
+        ), "Mismatched lengths of amax tensors."
+
+        if dist.is_initialized():
+            # Combine all the amax tensors into one tensor and reduce it
+            all_amax_tensors = torch.cat(
+                fp8_amax_x_tensor_list
+                + fp8_amax_w_tensor_list
+                + fp8_amax_dL_dY_tensor_list
+            )
+            all_reduced_amax_tensor = all_reduce(
+                all_amax_tensors, "MAX", list(range(dist.get_world_size()))
+            )
+            if isinstance(all_reduced_amax_tensor, AsyncCollectiveTensor):
+                all_reduced_amax_tensor = all_reduced_amax_tensor.wait()
+
+            (
+                reduced_fp8_amax_tensor,
+                reduced_fp8_amax_w_tensor,
+                reduced_fp8_amax_dL_dY_tensor,
+            ) = torch.split(all_reduced_amax_tensor, len(fp8_amax_x_tensor_list))
+
+            for idx, child in enumerate(fp8_layers):
+                child.fp8_amax_x.copy_(reduced_fp8_amax_tensor[idx])
+                child.fp8_amax_w.copy_(reduced_fp8_amax_w_tensor[idx])
+                child.fp8_amax_dL_dY.copy_(reduced_fp8_amax_dL_dY_tensor[idx])
+
+        # We create two stacked tensor groups, one for the amax history and one for the current scales
+        fp8_amax_x_tensors = torch.vstack(fp8_amax_x_tensor_list)
+        fp8_amax_w_tensors = torch.vstack(fp8_amax_w_tensor_list)
+        fp8_amax_dL_dY_tensors = torch.vstack(fp8_amax_dL_dY_tensor_list)
+
+        fp8_x_amax_history_stack = torch.vstack(fp8_x_amax_history_stack)
+        fp8_w_amax_history_stack = torch.vstack(fp8_w_amax_history_stack)
+        fp8_dL_dY_amax_history_stack = torch.vstack(fp8_dL_dY_amax_history_stack)
+
+        # Update the history stacks with the new amax values
+        _update_history_stack(fp8_amax_x_tensors, fp8_x_amax_history_stack)
+        _update_history_stack(fp8_amax_w_tensors, fp8_w_amax_history_stack)
+        _update_history_stack(fp8_amax_dL_dY_tensors, fp8_dL_dY_amax_history_stack)
+
+        # Calculate the new scales from the updated history stacks
+        new_x_scales = amax_history_to_scale_stack(
+            fp8_x_amax_history_stack, torch.float8_e4m3fn, x_dtype, scale_fn_recipe
         )
-        all_reduced_amax_tensor = all_reduce(
-            all_amax_tensors, "MAX", list(range(dist.get_world_size()))
+        new_w_scales = amax_history_to_scale_stack(
+            fp8_w_amax_history_stack, torch.float8_e4m3fn, x_dtype, scale_fn_recipe
+        )
+        new_dL_dY_scales = amax_history_to_scale_stack(
+            fp8_dL_dY_amax_history_stack, torch.float8_e5m2, x_dtype, scale_fn_recipe
         )
-        if isinstance(all_reduced_amax_tensor, AsyncCollectiveTensor):
-            all_reduced_amax_tensor = all_reduced_amax_tensor.wait()
-
-        (
-            reduced_fp8_amax_tensor,
-            reduced_fp8_amax_w_tensor,
-            reduced_fp8_amax_dL_dY_tensor,
-        ) = torch.split(all_reduced_amax_tensor, len(fp8_amax_x_tensor_list))
 
+        # Iterate through the layers and update the scales, and set the flag to signal that the amaxes/scales are ready
         for idx, child in enumerate(fp8_layers):
-            child.fp8_amax_x.copy_(reduced_fp8_amax_tensor[idx])
-            child.fp8_amax_w.copy_(reduced_fp8_amax_w_tensor[idx])
-            child.fp8_amax_dL_dY.copy_(reduced_fp8_amax_dL_dY_tensor[idx])
-
-    # We create two stacked tensor groups, one for the amax history and one for the current scales
-    fp8_amax_x_tensors = torch.vstack(fp8_amax_x_tensor_list)
-    fp8_amax_w_tensors = torch.vstack(fp8_amax_w_tensor_list)
-    fp8_amax_dL_dY_tensors = torch.vstack(fp8_amax_dL_dY_tensor_list)
-
-    fp8_x_amax_history_stack = torch.vstack(fp8_x_amax_history_stack)
-    fp8_w_amax_history_stack = torch.vstack(fp8_w_amax_history_stack)
-    fp8_dL_dY_amax_history_stack = torch.vstack(fp8_dL_dY_amax_history_stack)
-
-    # Update the history stacks with the new amax values
-    _update_history_stack(fp8_amax_x_tensors, fp8_x_amax_history_stack)
-    _update_history_stack(fp8_amax_w_tensors, fp8_w_amax_history_stack)
-    _update_history_stack(fp8_amax_dL_dY_tensors, fp8_dL_dY_amax_history_stack)
-
-    # Calculate the new scales from the updated history stacks
-    new_x_scales = amax_history_to_scale_stack(
-        fp8_x_amax_history_stack, torch.float8_e4m3fn, x_dtype, scale_fn_recipe
-    )
-    new_w_scales = amax_history_to_scale_stack(
-        fp8_w_amax_history_stack, torch.float8_e4m3fn, x_dtype, scale_fn_recipe
-    )
-    new_dL_dY_scales = amax_history_to_scale_stack(
-        fp8_dL_dY_amax_history_stack, torch.float8_e5m2, x_dtype, scale_fn_recipe
-    )
+            child.fp8_scale_x.copy_(new_x_scales[idx])
+            child.fp8_scale_w.copy_(new_w_scales[idx])
+            child.fp8_scale_dL_dY.copy_(new_dL_dY_scales[idx])
 
-    # Iterate through the layers and update the scales, and set the flag to signal that the amaxes/scales are ready
-    for idx, child in enumerate(fp8_layers):
-        child.fp8_scale_x.copy_(new_x_scales[idx])
-        child.fp8_scale_w.copy_(new_w_scales[idx])
-        child.fp8_scale_dL_dY.copy_(new_dL_dY_scales[idx])
+    # This allows for the compile to succeed on the inner func and fail on the graph breaks
+    # at the beginning and end of syncing
+    inner_func()
 
+    for child in fp8_layers:
         # 4. set a flag to signal amaxes/scales are ready
         # We only update the flag if we know it will be checked by the modules
-        if fp8_config.enable_amax_init:
-            child.amax_and_scale_synced = True
+        child.amax_and_scale_synced = True
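The new code above follows a common pattern for making a mixed Python/tensor routine friendly to `torch.compile`: keep the graph-break-prone bookkeeping (module discovery, attribute flags) in the outer function, and put all of the tensor math in an inner function that can be captured as a single graph. The sketch below is a minimal, self-contained illustration of that pattern only; it is not the library's code, and the names `toy_sync`, `_Layer`, `amax`, `scale`, and `synced` are hypothetical stand-ins.

```python
# Minimal sketch (assumptions noted above): isolate the tensor math in an inner
# function so torch.compile can trace it without hitting the graph breaks caused
# by the surrounding Python-level bookkeeping.
import torch
import torch.nn as nn


class _Layer(nn.Module):
    """Hypothetical layer holding per-layer amax/scale state as buffers."""

    def __init__(self):
        super().__init__()
        self.register_buffer("amax", torch.zeros(1))
        self.register_buffer("scale", torch.ones(1))


def toy_sync(model: nn.Module, layers=None):
    # Python-level module traversal: the kind of code that tends to cause
    # graph breaks, so it stays outside the compiled region.
    if layers is None:
        layers = [m for m in model.modules() if isinstance(m, _Layer)]

    @torch.compile
    def inner_func():
        # Pure tensor ops: stack the per-layer amaxes, derive new scales,
        # and write them back in place. This is the region we want compiled.
        amaxes = torch.vstack([layer.amax for layer in layers])
        new_scales = torch.finfo(torch.float8_e4m3fn).max / amaxes.clamp(min=1e-12)
        for idx, layer in enumerate(layers):
            layer.scale.copy_(new_scales[idx])

    inner_func()

    # Python attribute mutation, another graph-break source, also kept outside.
    for layer in layers:
        layer.synced = True


model = nn.Sequential(_Layer(), _Layer())
model[0].amax.fill_(3.0)
toy_sync(model)
print(model[0].scale.item(), model[1].scale.item())
```

In the diff itself the same split is what lets `inner_func` compile cleanly while the `get_float8_layers` lookup and the `amax_and_scale_synced` flag updates remain outside, as the added comments explain.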