This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit b508920

drisspg authored and facebook-github-bot committed
Add a compiled inner func to sync amax (#221)
Summary: See this comment on #220. I still need to verify that this works; I also want to add a test for when we compile the outer func, since I am not sure what happens there. cc awgu, bdhirsh

Pull Request resolved: #221

Reviewed By: bdhirsh

Differential Revision: D54074075

Pulled By: drisspg

fbshipit-source-id: 185dd60d39e866122f55cef78d1fba9b475088a4
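For context, the pattern this commit adopts is to keep the graph-break-prone Python bookkeeping (layer discovery, flag setting) in the outer function and move all of the tensor work into a nested function, so torch.compile can capture that inner region as a single graph. Below is a minimal sketch of the idea, using hypothetical helper names and an illustrative scale formula rather than this repository's actual code:

import torch
import torch.nn as nn

def sync_scales(model: nn.Module, layers=None):
    # Python-level module traversal: under torch.compile this part may graph-break.
    if layers is None:
        layers = [m for m in model.modules() if isinstance(m, nn.Linear)]

    def inner():
        # Pure tensor work, kept free of graph breaks so it can be captured
        # as one graph when the outer function is compiled.
        amaxes = torch.stack([layer.weight.detach().abs().max() for layer in layers])
        return 448.0 / amaxes.clamp(min=1e-12)  # illustrative float8-style scales

    scales = inner()

    # Python attribute mutation after the tensor work: another graph-break site,
    # but it no longer fragments the computation above.
    for layer, scale in zip(layers, scales.unbind()):
        layer.fp8_scale = scale
        layer.scale_synced = True
    return scales

# The outer function is still what gets compiled; any graph breaks land at its
# edges while inner() is captured as a single frame.
compiled_sync = torch.compile(sync_scales)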
1 parent 9cce2b9 commit b508920

File tree

2 files changed: +129 −91 lines


float8_experimental/float8_linear_utils.py

Lines changed: 109 additions & 90 deletions
@@ -185,103 +185,122 @@ def sync_float8_amax_and_scale_history(model: torch.nn.Module, fp8_layers=None)
         )
         return

-    # Loop over all fp8 layers and grab the needed tensors
-    fp8_amax_x_tensor_list = [None] * len(fp8_layers)
-    fp8_amax_w_tensor_list = [None] * len(fp8_layers)
-    fp8_amax_dL_dY_tensor_list = [None] * len(fp8_layers)
-
-    fp8_x_amax_history_stack = [None] * len(fp8_layers)
-    fp8_w_amax_history_stack = [None] * len(fp8_layers)
-    fp8_dL_dY_amax_history_stack = [None] * len(fp8_layers)
-
-    x_dtypes = set()
-    scale_fn_recipes = set()
-
-    for idx, child in enumerate(fp8_layers):
-        fp8_amax_x_tensor_list[idx] = child.fp8_amax_x
-        fp8_amax_w_tensor_list[idx] = child.fp8_amax_w
-        fp8_amax_dL_dY_tensor_list[idx] = child.fp8_amax_dL_dY
-
-        fp8_x_amax_history_stack[idx] = child.fp8_amax_history_x
-        fp8_w_amax_history_stack[idx] = child.fp8_amax_history_w
-        fp8_dL_dY_amax_history_stack[idx] = child.fp8_amax_history_dL_dY
-
-        x_dtypes.add(child.last_seen_input_dtype)
-        scale_fn_recipes.add(child.recipe.scale_fn_name)
-
-    # TODO This way to get the activation dtype is not ideal
-    if len(x_dtypes) != 1:
-        raise ValueError(
-            f"All layers must have the same last seen input_dtype, got {x_dtypes}"
-        )
-    x_dtype = next(iter(x_dtypes))
+    def inner_func():
+        """Why do we have this inner_function?
+
+        There are two portions of the outer sync_function that cause graph_breaks:
+            1. The `get_float8_layers` call can cause graph breaks if the user did not pass
+               in the fp8_layers.
+            2. At the end of syncing all the amaxes and scales we set the attr on the module
+               signaling that we have synced the amaxes and scales and the next forward can be run.
+               # TODO Maybe we should remove this safety check to remove the graph break?
+
+        By having this inner function, we can ensure that although the outer function may cause graph breaks
+        the inner function will not.
+        """
+        # Loop over all fp8 layers and grab the needed tensors
+        fp8_amax_x_tensor_list = [None] * len(fp8_layers)
+        fp8_amax_w_tensor_list = [None] * len(fp8_layers)
+        fp8_amax_dL_dY_tensor_list = [None] * len(fp8_layers)
+
+        fp8_x_amax_history_stack = [None] * len(fp8_layers)
+        fp8_w_amax_history_stack = [None] * len(fp8_layers)
+        fp8_dL_dY_amax_history_stack = [None] * len(fp8_layers)
+
+        x_dtypes = set()
+        scale_fn_recipes = set()

-    if len(scale_fn_recipes) != 1:
-        raise ValueError(
-            f"All layers must have the same scale_fn recipe, got {scale_fn_recipes}"
-        )
-    scale_fn_recipe = next(iter(scale_fn_recipes))
+        for idx, child in enumerate(fp8_layers):
+            fp8_amax_x_tensor_list[idx] = child.fp8_amax_x
+            fp8_amax_w_tensor_list[idx] = child.fp8_amax_w
+            fp8_amax_dL_dY_tensor_list[idx] = child.fp8_amax_dL_dY

-    assert (
-        len(fp8_amax_x_tensor_list)
-        == len(fp8_amax_w_tensor_list)
-        == len(fp8_amax_dL_dY_tensor_list)
-    ), "Mismatched lengths of amax tensors."
-
-    if dist.is_initialized():
-        # Combine all the amax tensors into one tensor and reduce it
-        all_amax_tensors = torch.cat(
-            fp8_amax_x_tensor_list + fp8_amax_w_tensor_list + fp8_amax_dL_dY_tensor_list
+            fp8_x_amax_history_stack[idx] = child.fp8_amax_history_x
+            fp8_w_amax_history_stack[idx] = child.fp8_amax_history_w
+            fp8_dL_dY_amax_history_stack[idx] = child.fp8_amax_history_dL_dY
+
+            x_dtypes.add(child.last_seen_input_dtype)
+            scale_fn_recipes.add(child.recipe.scale_fn_name)
+
+        # TODO This way to get the activation dtype is not ideal
+        if len(x_dtypes) != 1:
+            raise ValueError(
+                f"All layers must have the same last seen input_dtype, got {x_dtypes}"
+            )
+        x_dtype = next(iter(x_dtypes))
+
+        if len(scale_fn_recipes) != 1:
+            raise ValueError(
+                f"All layers must have the same scale_fn recipe, got {scale_fn_recipes}"
+            )
+        scale_fn_recipe = next(iter(scale_fn_recipes))
+
+        assert (
+            len(fp8_amax_x_tensor_list)
+            == len(fp8_amax_w_tensor_list)
+            == len(fp8_amax_dL_dY_tensor_list)
+        ), "Mismatched lengths of amax tensors."
+
+        if dist.is_initialized():
+            # Combine all the amax tensors into one tensor and reduce it
+            all_amax_tensors = torch.cat(
+                fp8_amax_x_tensor_list
+                + fp8_amax_w_tensor_list
+                + fp8_amax_dL_dY_tensor_list
+            )
+            all_reduced_amax_tensor = all_reduce(
+                all_amax_tensors, "MAX", list(range(dist.get_world_size()))
+            )
+            if isinstance(all_reduced_amax_tensor, AsyncCollectiveTensor):
+                all_reduced_amax_tensor = all_reduced_amax_tensor.wait()
+
+            (
+                reduced_fp8_amax_tensor,
+                reduced_fp8_amax_w_tensor,
+                reduced_fp8_amax_dL_dY_tensor,
+            ) = torch.split(all_reduced_amax_tensor, len(fp8_amax_x_tensor_list))
+
+            for idx, child in enumerate(fp8_layers):
+                child.fp8_amax_x.copy_(reduced_fp8_amax_tensor[idx])
+                child.fp8_amax_w.copy_(reduced_fp8_amax_w_tensor[idx])
+                child.fp8_amax_dL_dY.copy_(reduced_fp8_amax_dL_dY_tensor[idx])
+
+        # We create two stacked tensor groups, one for the amax history and one for the current scales
+        fp8_amax_x_tensors = torch.vstack(fp8_amax_x_tensor_list)
+        fp8_amax_w_tensors = torch.vstack(fp8_amax_w_tensor_list)
+        fp8_amax_dL_dY_tensors = torch.vstack(fp8_amax_dL_dY_tensor_list)
+
+        fp8_x_amax_history_stack = torch.vstack(fp8_x_amax_history_stack)
+        fp8_w_amax_history_stack = torch.vstack(fp8_w_amax_history_stack)
+        fp8_dL_dY_amax_history_stack = torch.vstack(fp8_dL_dY_amax_history_stack)
+
+        # Update the history stacks with the new amax values
+        _update_history_stack(fp8_amax_x_tensors, fp8_x_amax_history_stack)
+        _update_history_stack(fp8_amax_w_tensors, fp8_w_amax_history_stack)
+        _update_history_stack(fp8_amax_dL_dY_tensors, fp8_dL_dY_amax_history_stack)
+
+        # Calculate the new scales from the updated history stacks
+        new_x_scales = amax_history_to_scale_stack(
+            fp8_x_amax_history_stack, torch.float8_e4m3fn, x_dtype, scale_fn_recipe
         )
-        all_reduced_amax_tensor = all_reduce(
-            all_amax_tensors, "MAX", list(range(dist.get_world_size()))
+        new_w_scales = amax_history_to_scale_stack(
+            fp8_w_amax_history_stack, torch.float8_e4m3fn, x_dtype, scale_fn_recipe
+        )
+        new_dL_dY_scales = amax_history_to_scale_stack(
+            fp8_dL_dY_amax_history_stack, torch.float8_e5m2, x_dtype, scale_fn_recipe
         )
-        if isinstance(all_reduced_amax_tensor, AsyncCollectiveTensor):
-            all_reduced_amax_tensor = all_reduced_amax_tensor.wait()
-
-        (
-            reduced_fp8_amax_tensor,
-            reduced_fp8_amax_w_tensor,
-            reduced_fp8_amax_dL_dY_tensor,
-        ) = torch.split(all_reduced_amax_tensor, len(fp8_amax_x_tensor_list))

+        # Iterate through the layers and update the scales, and set the flag to signal that the amaxes/scales are ready
         for idx, child in enumerate(fp8_layers):
-            child.fp8_amax_x.copy_(reduced_fp8_amax_tensor[idx])
-            child.fp8_amax_w.copy_(reduced_fp8_amax_w_tensor[idx])
-            child.fp8_amax_dL_dY.copy_(reduced_fp8_amax_dL_dY_tensor[idx])
-
-    # We create two stacked tensor groups, one for the amax history and one for the current scales
-    fp8_amax_x_tensors = torch.vstack(fp8_amax_x_tensor_list)
-    fp8_amax_w_tensors = torch.vstack(fp8_amax_w_tensor_list)
-    fp8_amax_dL_dY_tensors = torch.vstack(fp8_amax_dL_dY_tensor_list)
-
-    fp8_x_amax_history_stack = torch.vstack(fp8_x_amax_history_stack)
-    fp8_w_amax_history_stack = torch.vstack(fp8_w_amax_history_stack)
-    fp8_dL_dY_amax_history_stack = torch.vstack(fp8_dL_dY_amax_history_stack)
-
-    # Update the history stacks with the new amax values
-    _update_history_stack(fp8_amax_x_tensors, fp8_x_amax_history_stack)
-    _update_history_stack(fp8_amax_w_tensors, fp8_w_amax_history_stack)
-    _update_history_stack(fp8_amax_dL_dY_tensors, fp8_dL_dY_amax_history_stack)
-
-    # Calculate the new scales from the updated history stacks
-    new_x_scales = amax_history_to_scale_stack(
-        fp8_x_amax_history_stack, torch.float8_e4m3fn, x_dtype, scale_fn_recipe
-    )
-    new_w_scales = amax_history_to_scale_stack(
-        fp8_w_amax_history_stack, torch.float8_e4m3fn, x_dtype, scale_fn_recipe
-    )
-    new_dL_dY_scales = amax_history_to_scale_stack(
-        fp8_dL_dY_amax_history_stack, torch.float8_e5m2, x_dtype, scale_fn_recipe
-    )
+            child.fp8_scale_x.copy_(new_x_scales[idx])
+            child.fp8_scale_w.copy_(new_w_scales[idx])
+            child.fp8_scale_dL_dY.copy_(new_dL_dY_scales[idx])

-    # Iterate through the layers and update the scales, and set the flag to signal that the amaxes/scales are ready
-    for idx, child in enumerate(fp8_layers):
-        child.fp8_scale_x.copy_(new_x_scales[idx])
-        child.fp8_scale_w.copy_(new_w_scales[idx])
-        child.fp8_scale_dL_dY.copy_(new_dL_dY_scales[idx])
+    # This allows for the compile to succeed on the inner func and fail on the graph breaks
+    # at the beginning and end of syncing
+    inner_func()

+    for child in fp8_layers:
         # 4. set a flag to signal amaxes/scales are ready
         # We only update the flag if we know it will be checked by the modules
-        if fp8_config.enable_amax_init:
-            child.amax_and_scale_synced = True
+        child.amax_and_scale_synced = True
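For orientation, here is a rough sketch of how the synced function is typically driven from a training loop once the Linear layers have been swapped to Float8Linear. The toy model, the optimizer, and the placement of the sync call between backward and optimizer.step() are illustrative assumptions following the delayed-scaling recipe, not something this commit prescribes; compiling the sync step is what the inner func makes practical:

import torch
import torch.nn as nn
from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_linear_utils import (
    swap_linear_with_float8_linear,
    sync_float8_amax_and_scale_history,
)

model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64)).cuda()
model = swap_linear_with_float8_linear(model, Float8Linear)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# With the tensor work isolated in the inner func, the sync step can be wrapped
# in torch.compile and captured as a single graph.
sync_amax_history = torch.compile(sync_float8_amax_and_scale_history)

for _ in range(10):
    x = torch.randn(16, 64, device="cuda")
    loss = model(x).sum()
    loss.backward()
    sync_amax_history(model)  # update amax histories and recompute scales
    optimizer.step()
    optimizer.zero_grad()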

test/test_compile.py

Lines changed: 20 additions & 1 deletion
@@ -11,7 +11,13 @@

 import torch
 import torch.nn as nn
-from float8_experimental.float8_linear_utils import get_float8_linear, LinearType
+from float8_experimental.float8_linear import Float8Linear
+from float8_experimental.float8_linear_utils import (
+    get_float8_linear,
+    LinearType,
+    swap_linear_with_float8_linear,
+    sync_float8_amax_and_scale_history,
+)
 from float8_experimental.float8_tensor import Float8Tensor

 from torch._dynamo.test_case import TestCase as DynamoTestCase
@@ -199,5 +205,18 @@ def test_float8_graph_output(self):
         )


+@unittest.skipIf(not torch.cuda.is_available() or not is_H100, "CUDA not available")
+def test_sync_amax_func():
+    torch._dynamo.reset()
+    cnts = CompileCounterWithBackend("inductor")
+    module = torch.nn.Sequential(
+        nn.Linear(16, 32, bias=True), nn.ReLU(), nn.Linear(32, 16, bias=True)
+    )
+    float8_mod = swap_linear_with_float8_linear(module, Float8Linear)
+    compiled_swap_func = torch.compile(sync_float8_amax_and_scale_history, backend=cnts)
+    compiled_swap_func(float8_mod)
+    assert cnts.frame_count == 1, "Compiled graph should have 1 frame!"
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
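The assertion at the end is the crux of the test: CompileCounterWithBackend counts how many graphs Dynamo actually hands to the backend, so frame_count == 1 means the tensor work in the inner func was captured as one graph rather than split into fragments by graph breaks. A small self-contained illustration of that counter, independent of this repository (exact counts can vary across PyTorch versions):

import torch
from torch._dynamo.testing import CompileCounterWithBackend

cnts = CompileCounterWithBackend("eager")

@torch.compile(backend=cnts)
def f(x):
    y = x.sin()
    torch._dynamo.graph_break()  # force a split into two graphs
    return y.cos()

f(torch.randn(4))
assert cnts.frame_count == 2  # one compiled frame per captured graph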
