Add a compiled inner func to sync amax #221
Updated hunk (new version shown; the pre-existing sync body is moved, indented one level, into inner_func()):

@@ -185,103 +185,122 @@ def sync_float8_amax_and_scale_history(model: torch.nn.Module, fp8_layers=None)
        )
        return

    def inner_func():
        """Why do we have this inner_function?

        There are two portions of the outer sync_function that cause graph_breaks:
            1. The `get_float8_layers` call can cause graph breaks if the user did not pass
               in the fp8_layers.
            2. At the end of syncing all the amaxes and scales we set the attr on the module
               signaling that we have synced the amaxes and scales and the next forward can be run.
               # TODO Maybe we should remove this safety check to remove the graph break?

        By having this inner function, we can ensure that although the outer function may cause graph breaks
        the inner function will not.
        """
        # Loop over all fp8 layers and grab the needed tensors
        fp8_amax_x_tensor_list = [None] * len(fp8_layers)
        fp8_amax_w_tensor_list = [None] * len(fp8_layers)
        fp8_amax_dL_dY_tensor_list = [None] * len(fp8_layers)

        fp8_x_amax_history_stack = [None] * len(fp8_layers)
        fp8_w_amax_history_stack = [None] * len(fp8_layers)
        fp8_dL_dY_amax_history_stack = [None] * len(fp8_layers)

        x_dtypes = set()
        scale_fn_recipes = set()

        for idx, child in enumerate(fp8_layers):
            fp8_amax_x_tensor_list[idx] = child.fp8_amax_x
            fp8_amax_w_tensor_list[idx] = child.fp8_amax_w
            fp8_amax_dL_dY_tensor_list[idx] = child.fp8_amax_dL_dY

            fp8_x_amax_history_stack[idx] = child.fp8_amax_history_x
            fp8_w_amax_history_stack[idx] = child.fp8_amax_history_w
            fp8_dL_dY_amax_history_stack[idx] = child.fp8_amax_history_dL_dY

            x_dtypes.add(child.last_seen_input_dtype)
            scale_fn_recipes.add(child.recipe.scale_fn_name)

        # TODO This way to get the activation dtype is not ideal
        if len(x_dtypes) != 1:
            raise ValueError(
                f"All layers must have the same last seen input_dtype, got {x_dtypes}"
            )
        x_dtype = next(iter(x_dtypes))

        if len(scale_fn_recipes) != 1:
            raise ValueError(
                f"All layers must have the same scale_fn recipe, got {scale_fn_recipes}"
            )
        scale_fn_recipe = next(iter(scale_fn_recipes))

        assert (
            len(fp8_amax_x_tensor_list)
            == len(fp8_amax_w_tensor_list)
            == len(fp8_amax_dL_dY_tensor_list)
        ), "Mismatched lengths of amax tensors."

        if dist.is_initialized():
            # Combine all the amax tensors into one tensor and reduce it
            all_amax_tensors = torch.cat(
                fp8_amax_x_tensor_list
                + fp8_amax_w_tensor_list
                + fp8_amax_dL_dY_tensor_list
            )
            all_reduced_amax_tensor = all_reduce(
                all_amax_tensors, "MAX", list(range(dist.get_world_size()))
            )
            if isinstance(all_reduced_amax_tensor, AsyncCollectiveTensor):
                all_reduced_amax_tensor = all_reduced_amax_tensor.wait()

            (
                reduced_fp8_amax_tensor,
                reduced_fp8_amax_w_tensor,
                reduced_fp8_amax_dL_dY_tensor,
            ) = torch.split(all_reduced_amax_tensor, len(fp8_amax_x_tensor_list))

            for idx, child in enumerate(fp8_layers):
                child.fp8_amax_x.copy_(reduced_fp8_amax_tensor[idx])
                child.fp8_amax_w.copy_(reduced_fp8_amax_w_tensor[idx])
                child.fp8_amax_dL_dY.copy_(reduced_fp8_amax_dL_dY_tensor[idx])

        # We create two stacked tensor groups, one for the amax history and one for the current scales
        fp8_amax_x_tensors = torch.vstack(fp8_amax_x_tensor_list)
        fp8_amax_w_tensors = torch.vstack(fp8_amax_w_tensor_list)
        fp8_amax_dL_dY_tensors = torch.vstack(fp8_amax_dL_dY_tensor_list)

        fp8_x_amax_history_stack = torch.vstack(fp8_x_amax_history_stack)
        fp8_w_amax_history_stack = torch.vstack(fp8_w_amax_history_stack)
        fp8_dL_dY_amax_history_stack = torch.vstack(fp8_dL_dY_amax_history_stack)

        # Update the history stacks with the new amax values
        _update_history_stack(fp8_amax_x_tensors, fp8_x_amax_history_stack)
        _update_history_stack(fp8_amax_w_tensors, fp8_w_amax_history_stack)
        _update_history_stack(fp8_amax_dL_dY_tensors, fp8_dL_dY_amax_history_stack)

        # Calculate the new scales from the updated history stacks
        new_x_scales = amax_history_to_scale_stack(
            fp8_x_amax_history_stack, torch.float8_e4m3fn, x_dtype, scale_fn_recipe
        )
        new_w_scales = amax_history_to_scale_stack(
            fp8_w_amax_history_stack, torch.float8_e4m3fn, x_dtype, scale_fn_recipe
        )
        new_dL_dY_scales = amax_history_to_scale_stack(
            fp8_dL_dY_amax_history_stack, torch.float8_e5m2, x_dtype, scale_fn_recipe
        )

        # Iterate through the layers and update the scales
        for idx, child in enumerate(fp8_layers):
            child.fp8_scale_x.copy_(new_x_scales[idx])
            child.fp8_scale_w.copy_(new_w_scales[idx])
            child.fp8_scale_dL_dY.copy_(new_dL_dY_scales[idx])

    # This allows the compile to succeed on the inner func and fail on the graph breaks
    # at the beginning and end of syncing
    inner_func()
Review comment: a little hard to tell, but I think this is all code motion? (moving it inside of inner_func())

Reply: yeah, tabs strike again here, but it is just indented one layer.
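For context, a minimal sketch of how this sync function is typically driven from a training loop. The import path, the toy model setup, and the torch.compile wrapping are assumptions for illustration, not part of this diff:

    import torch
    import torch.nn as nn

    # Import path assumed from this repo's layout.
    from float8_experimental.float8_linear_utils import sync_float8_amax_and_scale_history

    # Toy model; in real use its nn.Linear layers would first be swapped to Float8Linear
    # (swap step elided here) so that delayed scaling has amax/scale buffers to sync.
    model = nn.Sequential(nn.Linear(32, 32), nn.Linear(32, 2))
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    # Compile the sync function; per this PR, the body of inner_func should compile as a
    # single graph even if the surrounding setup/bookkeeping still graph-breaks.
    sync_func = torch.compile(sync_float8_amax_and_scale_history)

    for _ in range(3):
        optimizer.zero_grad()
        loss = model(torch.randn(8, 32)).sum()
        loss.backward()
        sync_func(model)  # sync amaxes/scales before the optimizer step
        optimizer.step()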
    for child in fp8_layers:
        # 4. set a flag to signal amaxes/scales are ready
        # We only update the flag if we know it will be checked by the modules
        if fp8_config.enable_amax_init:
            child.amax_and_scale_synced = True

Review comment: is this the thing where you mentioned that this config is actually needed for correctness?

Reply: this was Andrew's original comment here: #220 (comment). I used the wrong config setting.
Review comment: maybe add a comment on why we want the inner fn? (we have graph breaks that we want to fix eventually, but in the meantime we want to ensure that everything in inner_func is compiled)
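To illustrate the pattern that comment describes, here is a generic, self-contained sketch (not this repo's code; all names are hypothetical): side-effecting bookkeeping that tends to graph-break stays in the outer function, while the pure tensor work lives in an inner function that torch.compile can capture as one graph.

    import torch

    def sync_outer(module: torch.nn.Module, amaxes: list) -> torch.Tensor:
        def inner() -> torch.Tensor:
            # Pure tensor math: this is the part we want compiled without breaks.
            stacked = torch.vstack(amaxes)
            return stacked.abs().amax(dim=1)

        result = inner()
        # Python-level bookkeeping (like setting a flag attribute on the module)
        # is the kind of thing that can cause a graph break, so it stays outside.
        module.bookkeeping_flag = True
        return result

    compiled_sync = torch.compile(sync_outer)
    out = compiled_sync(torch.nn.Linear(4, 4), [torch.randn(4) for _ in range(3)])
    print(out.shape)  # torch.Size([3])

The attribute write here is only a stand-in for the flag-setting and get_float8_layers logic that this PR keeps outside inner_func.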