This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit ea368d1

make sync have a compiled inner func
1 parent 9cce2b9 commit ea368d1
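
The pattern in this commit is to keep the public sync function in eager mode but move all of the per-step tensor bookkeeping into a nested inner_func that is passed through torch.compile, so the batched amax/scale math runs as a compiled graph instead of many small eager ops. Below is a minimal, hypothetical sketch of that structure; the function and variable names are illustrative only and do not appear in the repository:

import torch

def sync_running_max(running_max, new_vals):
    # Hypothetical stand-in for the fp8 sync: batch per-layer updates in one compiled step.

    def inner_func():
        # Stack the per-layer tensors so the update is a few batched ops rather
        # than a Python loop of tiny kernels.
        stacked_max = torch.vstack(running_max)
        stacked_new = torch.vstack(new_vals)
        updated = torch.maximum(stacked_max, stacked_new)
        # Write the results back into the original per-layer buffers.
        for idx, buf in enumerate(running_max):
            buf.copy_(updated[idx])

    # torch.compile caches the compiled artifact for inner_func's code, so repeated
    # calls reuse it; mode="reduce-overhead" (CUDA graphs) is a possible follow-up,
    # as the commit's own comment notes.
    compiled_inner_func = torch.compile(inner_func)
    compiled_inner_func()

# Example usage with two "layers" of running statistics:
stats = [torch.zeros(1), torch.zeros(1)]
sync_running_max(stats, [torch.tensor([0.5]), torch.tensor([2.0])])
print(stats)  # [tensor([0.5000]), tensor([2.])]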

File tree: 1 file changed (+96 −88 lines)


float8_experimental/float8_linear_utils.py

Lines changed: 96 additions & 88 deletions
@@ -185,102 +185,110 @@ def sync_float8_amax_and_scale_history(model: torch.nn.Module, fp8_layers=None)
         )
         return
 
-    # Loop over all fp8 layers and grab the needed tensors
-    fp8_amax_x_tensor_list = [None] * len(fp8_layers)
-    fp8_amax_w_tensor_list = [None] * len(fp8_layers)
-    fp8_amax_dL_dY_tensor_list = [None] * len(fp8_layers)
-
-    fp8_x_amax_history_stack = [None] * len(fp8_layers)
-    fp8_w_amax_history_stack = [None] * len(fp8_layers)
-    fp8_dL_dY_amax_history_stack = [None] * len(fp8_layers)
-
-    x_dtypes = set()
-    scale_fn_recipes = set()
-
-    for idx, child in enumerate(fp8_layers):
-        fp8_amax_x_tensor_list[idx] = child.fp8_amax_x
-        fp8_amax_w_tensor_list[idx] = child.fp8_amax_w
-        fp8_amax_dL_dY_tensor_list[idx] = child.fp8_amax_dL_dY
-
-        fp8_x_amax_history_stack[idx] = child.fp8_amax_history_x
-        fp8_w_amax_history_stack[idx] = child.fp8_amax_history_w
-        fp8_dL_dY_amax_history_stack[idx] = child.fp8_amax_history_dL_dY
-
-        x_dtypes.add(child.last_seen_input_dtype)
-        scale_fn_recipes.add(child.recipe.scale_fn_name)
-
-    # TODO This way to get the activation dtype is not ideal
-    if len(x_dtypes) != 1:
-        raise ValueError(
-            f"All layers must have the same last seen input_dtype, got {x_dtypes}"
-        )
-    x_dtype = next(iter(x_dtypes))
-
-    if len(scale_fn_recipes) != 1:
-        raise ValueError(
-            f"All layers must have the same scale_fn recipe, got {scale_fn_recipes}"
-        )
-    scale_fn_recipe = next(iter(scale_fn_recipes))
-
-    assert (
-        len(fp8_amax_x_tensor_list)
-        == len(fp8_amax_w_tensor_list)
-        == len(fp8_amax_dL_dY_tensor_list)
-    ), "Mismatched lengths of amax tensors."
-
-    if dist.is_initialized():
-        # Combine all the amax tensors into one tensor and reduce it
-        all_amax_tensors = torch.cat(
-            fp8_amax_x_tensor_list + fp8_amax_w_tensor_list + fp8_amax_dL_dY_tensor_list
-        )
-        all_reduced_amax_tensor = all_reduce(
-            all_amax_tensors, "MAX", list(range(dist.get_world_size()))
-        )
-        if isinstance(all_reduced_amax_tensor, AsyncCollectiveTensor):
-            all_reduced_amax_tensor = all_reduced_amax_tensor.wait()
-
-        (
-            reduced_fp8_amax_tensor,
-            reduced_fp8_amax_w_tensor,
-            reduced_fp8_amax_dL_dY_tensor,
-        ) = torch.split(all_reduced_amax_tensor, len(fp8_amax_x_tensor_list))
-
-        for idx, child in enumerate(fp8_layers):
-            child.fp8_amax_x.copy_(reduced_fp8_amax_tensor[idx])
-            child.fp8_amax_w.copy_(reduced_fp8_amax_w_tensor[idx])
-            child.fp8_amax_dL_dY.copy_(reduced_fp8_amax_dL_dY_tensor[idx])
-
-    # We create two stacked tensor groups, one for the amax history and one for the current scales
-    fp8_amax_x_tensors = torch.vstack(fp8_amax_x_tensor_list)
-    fp8_amax_w_tensors = torch.vstack(fp8_amax_w_tensor_list)
-    fp8_amax_dL_dY_tensors = torch.vstack(fp8_amax_dL_dY_tensor_list)
-
-    fp8_x_amax_history_stack = torch.vstack(fp8_x_amax_history_stack)
-    fp8_w_amax_history_stack = torch.vstack(fp8_w_amax_history_stack)
-    fp8_dL_dY_amax_history_stack = torch.vstack(fp8_dL_dY_amax_history_stack)
-
-    # Update the history stacks with the new amax values
-    _update_history_stack(fp8_amax_x_tensors, fp8_x_amax_history_stack)
-    _update_history_stack(fp8_amax_w_tensors, fp8_w_amax_history_stack)
-    _update_history_stack(fp8_amax_dL_dY_tensors, fp8_dL_dY_amax_history_stack)
-
-    # Calculate the new scales from the updated history stacks
-    new_x_scales = amax_history_to_scale_stack(
-        fp8_x_amax_history_stack, torch.float8_e4m3fn, x_dtype, scale_fn_recipe
-    )
-    new_w_scales = amax_history_to_scale_stack(
-        fp8_w_amax_history_stack, torch.float8_e4m3fn, x_dtype, scale_fn_recipe
-    )
-    new_dL_dY_scales = amax_history_to_scale_stack(
-        fp8_dL_dY_amax_history_stack, torch.float8_e5m2, x_dtype, scale_fn_recipe
-    )
-
-    # Iterate through the layers and update the scales, and set the flag to signal that the amaxes/scales are ready
-    for idx, child in enumerate(fp8_layers):
-        child.fp8_scale_x.copy_(new_x_scales[idx])
-        child.fp8_scale_w.copy_(new_w_scales[idx])
-        child.fp8_scale_dL_dY.copy_(new_dL_dY_scales[idx])
+    def inner_func():
+        # Loop over all fp8 layers and grab the needed tensors
+        fp8_amax_x_tensor_list = [None] * len(fp8_layers)
+        fp8_amax_w_tensor_list = [None] * len(fp8_layers)
+        fp8_amax_dL_dY_tensor_list = [None] * len(fp8_layers)
+
+        fp8_x_amax_history_stack = [None] * len(fp8_layers)
+        fp8_w_amax_history_stack = [None] * len(fp8_layers)
+        fp8_dL_dY_amax_history_stack = [None] * len(fp8_layers)
+
+        x_dtypes = set()
+        scale_fn_recipes = set()
+
+        for idx, child in enumerate(fp8_layers):
+            fp8_amax_x_tensor_list[idx] = child.fp8_amax_x
+            fp8_amax_w_tensor_list[idx] = child.fp8_amax_w
+            fp8_amax_dL_dY_tensor_list[idx] = child.fp8_amax_dL_dY
+
+            fp8_x_amax_history_stack[idx] = child.fp8_amax_history_x
+            fp8_w_amax_history_stack[idx] = child.fp8_amax_history_w
+            fp8_dL_dY_amax_history_stack[idx] = child.fp8_amax_history_dL_dY
+
+            x_dtypes.add(child.last_seen_input_dtype)
+            scale_fn_recipes.add(child.recipe.scale_fn_name)
+
+        # TODO This way to get the activation dtype is not ideal
+        if len(x_dtypes) != 1:
+            raise ValueError(
+                f"All layers must have the same last seen input_dtype, got {x_dtypes}"
+            )
+        x_dtype = next(iter(x_dtypes))
+
+        if len(scale_fn_recipes) != 1:
+            raise ValueError(
+                f"All layers must have the same scale_fn recipe, got {scale_fn_recipes}"
+            )
+        scale_fn_recipe = next(iter(scale_fn_recipes))
+
+        assert (
+            len(fp8_amax_x_tensor_list)
+            == len(fp8_amax_w_tensor_list)
+            == len(fp8_amax_dL_dY_tensor_list)
+        ), "Mismatched lengths of amax tensors."
+
+        if dist.is_initialized():
+            # Combine all the amax tensors into one tensor and reduce it
+            all_amax_tensors = torch.cat(
+                fp8_amax_x_tensor_list
+                + fp8_amax_w_tensor_list
+                + fp8_amax_dL_dY_tensor_list
+            )
+            all_reduced_amax_tensor = all_reduce(
+                all_amax_tensors, "MAX", list(range(dist.get_world_size()))
+            )
+            if isinstance(all_reduced_amax_tensor, AsyncCollectiveTensor):
+                all_reduced_amax_tensor = all_reduced_amax_tensor.wait()
+
+            (
+                reduced_fp8_amax_tensor,
+                reduced_fp8_amax_w_tensor,
+                reduced_fp8_amax_dL_dY_tensor,
+            ) = torch.split(all_reduced_amax_tensor, len(fp8_amax_x_tensor_list))
+
+            for idx, child in enumerate(fp8_layers):
+                child.fp8_amax_x.copy_(reduced_fp8_amax_tensor[idx])
+                child.fp8_amax_w.copy_(reduced_fp8_amax_w_tensor[idx])
+                child.fp8_amax_dL_dY.copy_(reduced_fp8_amax_dL_dY_tensor[idx])
+
+        # We create two stacked tensor groups, one for the amax history and one for the current scales
+        fp8_amax_x_tensors = torch.vstack(fp8_amax_x_tensor_list)
+        fp8_amax_w_tensors = torch.vstack(fp8_amax_w_tensor_list)
+        fp8_amax_dL_dY_tensors = torch.vstack(fp8_amax_dL_dY_tensor_list)
+
+        fp8_x_amax_history_stack = torch.vstack(fp8_x_amax_history_stack)
+        fp8_w_amax_history_stack = torch.vstack(fp8_w_amax_history_stack)
+        fp8_dL_dY_amax_history_stack = torch.vstack(fp8_dL_dY_amax_history_stack)
+
+        # Update the history stacks with the new amax values
+        _update_history_stack(fp8_amax_x_tensors, fp8_x_amax_history_stack)
+        _update_history_stack(fp8_amax_w_tensors, fp8_w_amax_history_stack)
+        _update_history_stack(fp8_amax_dL_dY_tensors, fp8_dL_dY_amax_history_stack)
+
+        # Calculate the new scales from the updated history stacks
+        new_x_scales = amax_history_to_scale_stack(
+            fp8_x_amax_history_stack, torch.float8_e4m3fn, x_dtype, scale_fn_recipe
+        )
+        new_w_scales = amax_history_to_scale_stack(
+            fp8_w_amax_history_stack, torch.float8_e4m3fn, x_dtype, scale_fn_recipe
+        )
+        new_dL_dY_scales = amax_history_to_scale_stack(
+            fp8_dL_dY_amax_history_stack, torch.float8_e5m2, x_dtype, scale_fn_recipe
+        )
+
+        # Iterate through the layers and update the scales, and set the flag to signal that the amaxes/scales are ready
+        for idx, child in enumerate(fp8_layers):
+            child.fp8_scale_x.copy_(new_x_scales[idx])
+            child.fp8_scale_w.copy_(new_w_scales[idx])
+            child.fp8_scale_dL_dY.copy_(new_dL_dY_scales[idx])
+
+    # When cuda-graphs work we should compile with "reduce-overhead"
+    compiled_inner_func = torch.compile(inner_func)
+    compiled_inner_func()
 
+    for child in fp8_layers:
         # 4. set a flag to signal amaxes/scales are ready
         # We only update the flag if we know it will be checked by the modules
         if fp8_config.enable_amax_init:
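
For context on where this function runs, the delayed-scaling recipe typically calls the sync once per training step, after backward and before the optimizer step, so the next forward pass sees fresh scales. A hedged usage sketch follows; model, optimizer, and data_loader are assumed to exist, with the model's linear layers already swapped to Float8Linear:

import torch
from float8_experimental.float8_linear_utils import sync_float8_amax_and_scale_history

# Assumed training objects: a Float8Linear-converted `model`, an `optimizer`,
# and an iterable `data_loader` of (batch, target) pairs.
for batch, target in data_loader:
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(batch), target)
    loss.backward()
    # Sync amax buffers and recompute scales across layers (and across ranks,
    # if torch.distributed is initialized) before the next forward pass.
    sync_float8_amax_and_scale_history(model)
    optimizer.step()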
