This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Add a compiled inner func to sync amax #221

Closed
wants to merge 5 commits
199 changes: 109 additions & 90 deletions float8_experimental/float8_linear_utils.py
@@ -185,103 +185,122 @@ def sync_float8_amax_and_scale_history(model: torch.nn.Module, fp8_layers=None)
        )
        return

    def inner_func():
Contributor
maybe add a comment on why we want the inner fn? (we have graph breaks that we want to fix eventually, but in the meantime we want to ensure that everything in inner_func is compiled)

"""Why do we have this inner_function?

There are two portions of the outer sync_function that cause graph_breaks:
1. The `get_float8_layers` call can cause graph breaks if the user did not pass
in the fp8_layers.
2. At the end of syncing all the amaxes and scales we set the attr on the module
signaling that we have synced the amaxes and scales and the next forward can be run.
# TODO Maybe we should remove this safety check to remove the graph break?

By having this inner function, we can ensure that although the outer function may cause graph breaks
the inner function will not.
"""
# Loop over all fp8 layers and grab the needed tensors
fp8_amax_x_tensor_list = [None] * len(fp8_layers)
fp8_amax_w_tensor_list = [None] * len(fp8_layers)
fp8_amax_dL_dY_tensor_list = [None] * len(fp8_layers)

fp8_x_amax_history_stack = [None] * len(fp8_layers)
fp8_w_amax_history_stack = [None] * len(fp8_layers)
fp8_dL_dY_amax_history_stack = [None] * len(fp8_layers)

x_dtypes = set()
scale_fn_recipes = set()

if len(scale_fn_recipes) != 1:
raise ValueError(
f"All layers must have the same scale_fn recipe, got {scale_fn_recipes}"
)
scale_fn_recipe = next(iter(scale_fn_recipes))
for idx, child in enumerate(fp8_layers):
fp8_amax_x_tensor_list[idx] = child.fp8_amax_x
fp8_amax_w_tensor_list[idx] = child.fp8_amax_w
fp8_amax_dL_dY_tensor_list[idx] = child.fp8_amax_dL_dY

assert (
len(fp8_amax_x_tensor_list)
== len(fp8_amax_w_tensor_list)
== len(fp8_amax_dL_dY_tensor_list)
), "Mismatched lengths of amax tensors."

if dist.is_initialized():
# Combine all the amax tensors into one tensor and reduce it
all_amax_tensors = torch.cat(
fp8_amax_x_tensor_list + fp8_amax_w_tensor_list + fp8_amax_dL_dY_tensor_list
fp8_x_amax_history_stack[idx] = child.fp8_amax_history_x
fp8_w_amax_history_stack[idx] = child.fp8_amax_history_w
fp8_dL_dY_amax_history_stack[idx] = child.fp8_amax_history_dL_dY

x_dtypes.add(child.last_seen_input_dtype)
scale_fn_recipes.add(child.recipe.scale_fn_name)

# TODO This way to get the activation dtype is not ideal
if len(x_dtypes) != 1:
raise ValueError(
f"All layers must have the same last seen input_dtype, got {x_dtypes}"
)
x_dtype = next(iter(x_dtypes))

if len(scale_fn_recipes) != 1:
raise ValueError(
f"All layers must have the same scale_fn recipe, got {scale_fn_recipes}"
)
scale_fn_recipe = next(iter(scale_fn_recipes))

assert (
len(fp8_amax_x_tensor_list)
== len(fp8_amax_w_tensor_list)
== len(fp8_amax_dL_dY_tensor_list)
), "Mismatched lengths of amax tensors."

if dist.is_initialized():
# Combine all the amax tensors into one tensor and reduce it
all_amax_tensors = torch.cat(
fp8_amax_x_tensor_list
+ fp8_amax_w_tensor_list
+ fp8_amax_dL_dY_tensor_list
)
all_reduced_amax_tensor = all_reduce(
all_amax_tensors, "MAX", list(range(dist.get_world_size()))
)
if isinstance(all_reduced_amax_tensor, AsyncCollectiveTensor):
all_reduced_amax_tensor = all_reduced_amax_tensor.wait()

(
reduced_fp8_amax_tensor,
reduced_fp8_amax_w_tensor,
reduced_fp8_amax_dL_dY_tensor,
) = torch.split(all_reduced_amax_tensor, len(fp8_amax_x_tensor_list))

for idx, child in enumerate(fp8_layers):
child.fp8_amax_x.copy_(reduced_fp8_amax_tensor[idx])
child.fp8_amax_w.copy_(reduced_fp8_amax_w_tensor[idx])
child.fp8_amax_dL_dY.copy_(reduced_fp8_amax_dL_dY_tensor[idx])

# We create two stacked tensor groups, one for the amax history and one for the current scales
fp8_amax_x_tensors = torch.vstack(fp8_amax_x_tensor_list)
fp8_amax_w_tensors = torch.vstack(fp8_amax_w_tensor_list)
fp8_amax_dL_dY_tensors = torch.vstack(fp8_amax_dL_dY_tensor_list)

fp8_x_amax_history_stack = torch.vstack(fp8_x_amax_history_stack)
fp8_w_amax_history_stack = torch.vstack(fp8_w_amax_history_stack)
fp8_dL_dY_amax_history_stack = torch.vstack(fp8_dL_dY_amax_history_stack)

# Update the history stacks with the new amax values
_update_history_stack(fp8_amax_x_tensors, fp8_x_amax_history_stack)
_update_history_stack(fp8_amax_w_tensors, fp8_w_amax_history_stack)
_update_history_stack(fp8_amax_dL_dY_tensors, fp8_dL_dY_amax_history_stack)

# Calculate the new scales from the updated history stacks
new_x_scales = amax_history_to_scale_stack(
fp8_x_amax_history_stack, torch.float8_e4m3fn, x_dtype, scale_fn_recipe
)
all_reduced_amax_tensor = all_reduce(
all_amax_tensors, "MAX", list(range(dist.get_world_size()))
new_w_scales = amax_history_to_scale_stack(
fp8_w_amax_history_stack, torch.float8_e4m3fn, x_dtype, scale_fn_recipe
)
new_dL_dY_scales = amax_history_to_scale_stack(
fp8_dL_dY_amax_history_stack, torch.float8_e5m2, x_dtype, scale_fn_recipe
)
if isinstance(all_reduced_amax_tensor, AsyncCollectiveTensor):
all_reduced_amax_tensor = all_reduced_amax_tensor.wait()

(
reduced_fp8_amax_tensor,
reduced_fp8_amax_w_tensor,
reduced_fp8_amax_dL_dY_tensor,
) = torch.split(all_reduced_amax_tensor, len(fp8_amax_x_tensor_list))

# Iterate through the layers and update the scales, and set the flag to signal that the amaxes/scales are ready
for idx, child in enumerate(fp8_layers):
child.fp8_amax_x.copy_(reduced_fp8_amax_tensor[idx])
child.fp8_amax_w.copy_(reduced_fp8_amax_w_tensor[idx])
child.fp8_amax_dL_dY.copy_(reduced_fp8_amax_dL_dY_tensor[idx])

# We create two stacked tensor groups, one for the amax history and one for the current scales
fp8_amax_x_tensors = torch.vstack(fp8_amax_x_tensor_list)
fp8_amax_w_tensors = torch.vstack(fp8_amax_w_tensor_list)
fp8_amax_dL_dY_tensors = torch.vstack(fp8_amax_dL_dY_tensor_list)

fp8_x_amax_history_stack = torch.vstack(fp8_x_amax_history_stack)
fp8_w_amax_history_stack = torch.vstack(fp8_w_amax_history_stack)
fp8_dL_dY_amax_history_stack = torch.vstack(fp8_dL_dY_amax_history_stack)

# Update the history stacks with the new amax values
_update_history_stack(fp8_amax_x_tensors, fp8_x_amax_history_stack)
_update_history_stack(fp8_amax_w_tensors, fp8_w_amax_history_stack)
_update_history_stack(fp8_amax_dL_dY_tensors, fp8_dL_dY_amax_history_stack)

# Calculate the new scales from the updated history stacks
new_x_scales = amax_history_to_scale_stack(
fp8_x_amax_history_stack, torch.float8_e4m3fn, x_dtype, scale_fn_recipe
)
new_w_scales = amax_history_to_scale_stack(
fp8_w_amax_history_stack, torch.float8_e4m3fn, x_dtype, scale_fn_recipe
)
new_dL_dY_scales = amax_history_to_scale_stack(
fp8_dL_dY_amax_history_stack, torch.float8_e5m2, x_dtype, scale_fn_recipe
)
child.fp8_scale_x.copy_(new_x_scales[idx])
child.fp8_scale_w.copy_(new_w_scales[idx])
child.fp8_scale_dL_dY.copy_(new_dL_dY_scales[idx])

# Iterate through the layers and update the scales, and set the flag to signal that the amaxes/scales are ready
for idx, child in enumerate(fp8_layers):
child.fp8_scale_x.copy_(new_x_scales[idx])
child.fp8_scale_w.copy_(new_w_scales[idx])
child.fp8_scale_dL_dY.copy_(new_dL_dY_scales[idx])
# This allows for the compile to succede on the inner func and fail on the graph breaks
# at the beginning and and of syncing
inner_func()
Contributor
a little hard to tell but I think this is all code motion? (moving it inside of inner_func())

Contributor Author
yeah tabs strike again here, but just indented one layer


    for child in fp8_layers:
        # 4. set a flag to signal amaxes/scales are ready
        # We only update the flag if we know it will be checked by the modules
        if fp8_config.enable_amax_init:
Contributor
is this the thing where you mentioned that this config is actually needed for correctness?

Contributor Author
This was Andrew's original comment here: #220 (comment)

I used the wrong config setting
            child.amax_and_scale_synced = True
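
For context, here is a hedged sketch of how the sync function above (and the flag it sets) is typically driven from a delayed-scaling training loop, with the sync compiled as this PR intends. MyModel, dataloader, loss_fn, and the SGD optimizer are illustrative placeholders, not part of this PR; the swap and sync calls mirror the test added below.

import torch
from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_linear_utils import (
    swap_linear_with_float8_linear,
    sync_float8_amax_and_scale_history,
)

# MyModel, dataloader, and loss_fn are illustrative placeholders
model = swap_linear_with_float8_linear(MyModel().cuda(), Float8Linear)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
sync_fn = torch.compile(sync_float8_amax_and_scale_history)  # the compiled path this PR targets

for x, target in dataloader:
    optimizer.zero_grad()
    loss = loss_fn(model(x), target)
    loss.backward()
    sync_fn(model)  # refresh amax histories and scales before the next forward
    optimizer.step()
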
21 changes: 20 additions & 1 deletion test/test_compile.py
@@ -11,7 +11,13 @@

import torch
import torch.nn as nn
from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_linear_utils import (
    get_float8_linear,
    LinearType,
    swap_linear_with_float8_linear,
    sync_float8_amax_and_scale_history,
)
from float8_experimental.float8_tensor import Float8Tensor

from torch._dynamo.test_case import TestCase as DynamoTestCase
@@ -199,5 +205,18 @@ def test_float8_graph_output(self):
)


@unittest.skipIf(not torch.cuda.is_available() or not is_H100, "CUDA not available")
def test_sync_amax_func():
    torch._dynamo.reset()
    cnts = CompileCounterWithBackend("inductor")
    module = torch.nn.Sequential(
        nn.Linear(16, 32, bias=True), nn.ReLU(), nn.Linear(32, 16, bias=True)
    )
    float8_mod = swap_linear_with_float8_linear(module, Float8Linear)
    compiled_swap_func = torch.compile(sync_float8_amax_and_scale_history, backend=cnts)
    compiled_swap_func(float8_mod)
    assert cnts.frame_count == 1, "Compiled graph should have 1 frame!"


if __name__ == "__main__":
    pytest.main([__file__])
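
As a side note, here is a toy sketch (not the library's code) of the pattern the test exercises: side effects that tend to graph-break, such as attribute writes on a module, stay in the outer function, while the tensor math lives in an inner function so torch.compile can capture it as a single frame. The sync_stats helper and stats_synced attribute are made up for illustration, and the exact frame count can vary across PyTorch versions.

import torch
import torch.nn as nn
from torch._dynamo.testing import CompileCounterWithBackend


def sync_stats(tensors, module):
    # illustrative stand-in for the amax/scale sync above
    def inner():
        # pure tensor ops only: traceable as one graph
        return torch.stack([t.abs().max() for t in tensors])

    maxes = inner()
    module.stats_synced = True  # module attribute write, a typical graph-break site
    return maxes


cnts = CompileCounterWithBackend("eager")
compiled = torch.compile(sync_stats, backend=cnts)
compiled([torch.randn(8) for _ in range(3)], nn.Linear(8, 8))
print(cnts.frame_count)  # expect a small count; the PR's test asserts exactly 1
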