
Commit 336ef68

use one reduce instead of 3

1 parent 2beea99

File tree

1 file changed: +11 −9 lines

float8_experimental/float8_linear_utils.py

Lines changed: 11 additions & 9 deletions
@@ -207,15 +207,17 @@ def sync_float8_amax_and_scale_history(model: torch.nn.Module, fp8_layers=None)
 
     if dist.is_initialized():
         # Combine all the amax tensors into one tensor and reduce it
-        fp8_amax_x_tensor = torch.cat(fp8_amax_x_tensor_list)
-        fp8_amax_w_tensor = torch.cat(fp8_amax_w_tensor_list)
-        fp8_amax_dL_dY_tensor = torch.cat(fp8_amax_dL_dY_tensor_list)
-
-        reduced_fp8_amax_tensor = all_reduce(fp8_amax_x_tensor, "MAX", _get_default_group())
-        reduced_fp8_amax_w_tensor = all_reduce(fp8_amax_w_tensor, "MAX", _get_default_group())
-        reduced_fp8_amax_dL_dY_tensor = all_reduce(fp8_amax_dL_dY_tensor, "MAX", _get_default_group())
-
-        # Reassign the reduced amax values to the original tensors
+        all_amax_tensors = torch.cat(
+            fp8_amax_x_tensor_list + fp8_amax_w_tensor_list + fp8_amax_dL_dY_tensor_list
+        )
+        all_reduced_amax_tensor = all_reduce(
+            all_amax_tensors, "MAX", _get_default_group()
+        )
+        (
+            reduced_fp8_amax_tensor,
+            reduced_fp8_amax_w_tensor,
+            reduced_fp8_amax_dL_dY_tensor,
+        ) = torch.split(all_reduced_amax_tensor, len(fp8_amax_x_tensor_list))
     for idx, child in enumerate(fp8_layers):
         child.fp8_amax_x.copy_(reduced_fp8_amax_tensor[idx])
         child.fp8_amax_w.copy_(reduced_fp8_amax_w_tensor[idx])
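
The change launches one collective instead of three: the x, w, and dL_dY amax lists are concatenated into a single flat tensor, reduced once with MAX, and split back into three views. Below is a minimal standalone sketch of the same pattern; it is an illustration only, using the standard in-place torch.distributed.all_reduce rather than the functional all_reduce/_get_default_group pair in the diff, and the helper name fused_max_all_reduce is invented here.

import torch
import torch.distributed as dist

def fused_max_all_reduce(x_amaxes, w_amaxes, dl_dy_amaxes):
    # Assumes an initialized process group (the diff guards on
    # dist.is_initialized()) and that all three lists hold the same
    # number of 1-element amax tensors, mirroring
    # torch.split(..., len(fp8_amax_x_tensor_list)) in the diff above.
    flat = torch.cat(x_amaxes + w_amaxes + dl_dy_amaxes)
    # One collective launch instead of three.
    dist.all_reduce(flat, op=dist.ReduceOp.MAX)
    # Recover one reduced tensor per original list.
    return torch.split(flat, len(x_amaxes))

Batching the reduces this way trades three small collectives for one slightly larger one, which mainly saves per-call launch and synchronization overhead.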
