@@ -164,6 +164,24 @@ def get_float8_layers(model: torch.nn.Module):
164
164
return fp8_layers
165
165
166
166
167
def get_float8_layers_dtype(model: torch.nn.Module):
    """Return the (forward, backward) fp8 dtypes shared by all Float8Linear layers.

    Iterates through ``model``'s submodules and collects the ``fp8_dtype_fw``
    and ``fp8_dtype_bw`` attributes of every ``Float8Linear`` layer found.
    All layers must agree on both dtypes.

    Args:
        model (torch.nn.Module): The model to look for Float8Linear layers in.

    Returns:
        tuple: ``(fp8_dtype_fw, fp8_dtype_bw)`` — the single forward and
        backward fp8 dtypes used by every Float8Linear layer in the model.

    Raises:
        ValueError: If the model contains no Float8Linear layers, or if the
            layers do not all share the same forward/backward fp8 dtype.
    """
    fp8_dtype_fw = set()
    fp8_dtype_bw = set()
    # Collect the fp8 dtypes configured on each Float8Linear submodule.
    for child in model.modules():
        if isinstance(child, Float8Linear):
            fp8_dtype_fw.add(child.fp8_dtype_fw)
            fp8_dtype_bw.add(child.fp8_dtype_bw)

    # Raise (rather than assert) so validation survives `python -O`, and give
    # a distinct message for the no-layers case — previously an empty model
    # failed with the misleading "must have the same fp8_dtype_fw" message.
    if not fp8_dtype_fw or not fp8_dtype_bw:
        raise ValueError("Model contains no Float8Linear layers")
    if len(fp8_dtype_fw) != 1:
        raise ValueError("All fp8 layers must have the same fp8_dtype_fw")
    if len(fp8_dtype_bw) != 1:
        raise ValueError("All fp8 layers must have the same fp8_dtype_bw")
    return fp8_dtype_fw.pop(), fp8_dtype_bw.pop()
167
185
@torch .no_grad ()
168
186
def sync_float8_amax_and_scale_history (model : torch .nn .Module , fp8_layers = None ) -> None :
169
187
"""
@@ -197,6 +215,8 @@ def sync_float8_amax_and_scale_history(model: torch.nn.Module, fp8_layers=None)
197
215
)
198
216
return
199
217
218
+ fp8_dtype_fw , fp8_dtype_bw = get_float8_layers_dtype (model )
219
+
200
220
def inner_func ():
201
221
"""Why do we have this inner_function?
202
222
@@ -293,13 +313,13 @@ def inner_func():
293
313
294
314
# Calculate the new scales from the updated history stacks
295
315
new_x_scales = amax_history_to_scale_stack (
296
- fp8_x_amax_history_stack , torch . float8_e4m3fn , x_dtype , scale_fn_recipe
316
+ fp8_x_amax_history_stack , fp8_dtype_fw , x_dtype , scale_fn_recipe
297
317
)
298
318
new_w_scales = amax_history_to_scale_stack (
299
- fp8_w_amax_history_stack , torch . float8_e4m3fn , x_dtype , scale_fn_recipe
319
+ fp8_w_amax_history_stack , fp8_dtype_fw , x_dtype , scale_fn_recipe
300
320
)
301
321
new_dL_dY_scales = amax_history_to_scale_stack (
302
- fp8_dL_dY_amax_history_stack , torch . float8_e5m2 , x_dtype , scale_fn_recipe
322
+ fp8_dL_dY_amax_history_stack , fp8_dtype_bw , x_dtype , scale_fn_recipe
303
323
)
304
324
305
325
# Iterate through the layers and update the scales
0 commit comments