This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 026023f

use functional tensor

1 parent 29c059b

1 file changed: +5 −3

float8_experimental/float8_linear_utils.py

Lines changed: 5 additions & 3 deletions
@@ -15,8 +15,7 @@
 from float8_experimental.float8_linear import Float8Linear

 from float8_experimental.float8_utils import amax_history_to_scale_stack
-from torch.distributed._functional_collectives import all_reduce
-from torch.distributed.distributed_c10d import _get_default_group
+from torch.distributed._functional_collectives import all_reduce, AsyncCollectiveTensor

 log = logging.getLogger(__name__)
 log.addHandler(logging.NullHandler())
@@ -211,8 +210,11 @@ def sync_float8_amax_and_scale_history(model: torch.nn.Module, fp8_layers=None)
         fp8_amax_x_tensor_list + fp8_amax_w_tensor_list + fp8_amax_dL_dY_tensor_list
     )
     all_reduced_amax_tensor = all_reduce(
-        all_amax_tensors, "MAX", _get_default_group()
+        all_amax_tensors, "MAX", list(range(dist.get_world_size()))
     )
+    if isinstance(all_reduced_amax_tensor, AsyncCollectiveTensor):
+        all_reduced_amax_tensor = all_reduced_amax_tensor.wait()
+
     (
         reduced_fp8_amax_tensor,
         reduced_fp8_amax_w_tensor,
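
For context, a minimal sketch of the pattern this commit adopts, assuming a process group has already been initialized (e.g. launched with torchrun and set up via dist.init_process_group); the helper name max_across_ranks is hypothetical, not part of the repo:

    import torch
    import torch.distributed as dist
    from torch.distributed._functional_collectives import (
        AsyncCollectiveTensor,
        all_reduce,
    )

    def max_across_ranks(t: torch.Tensor) -> torch.Tensor:
        # Hypothetical helper. The functional all_reduce is out-of-place:
        # it returns a new tensor rather than mutating `t`. The group is
        # given as an explicit list of ranks instead of the private
        # _get_default_group() handle the old import pulled in.
        reduced = all_reduce(t, "MAX", list(range(dist.get_world_size())))
        # In eager mode the result may come back wrapped in an
        # AsyncCollectiveTensor; wait() blocks until the collective has
        # completed and returns the underlying plain tensor.
        if isinstance(reduced, AsyncCollectiveTensor):
            reduced = reduced.wait()
        return reduced

The explicit isinstance check matters because the wrapper is only produced in some execution modes, so code that immediately consumes the reduced values (as the amax sync does) should unwrap before use.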
