This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 746519f

Author: Andrew Gu

Used functional all-reduce for amax reduction

ghstack-source-id: 8bbd5fa
Pull Request resolved: #219

1 parent 956195b

File tree

1 file changed: +6 −1 lines changed

float8_experimental/float8_utils.py

Lines changed: 6 additions & 1 deletion
@@ -75,7 +75,12 @@ def tensor_to_amax(x, distributed_reduction=False):
     # If the user did not ask for it, assume that it will
     # happen elsewhere.
     if distributed_reduction and dist.is_initialized():
-        dist.all_reduce(amax, op=dist.ReduceOp.MAX)
+        # TODO: Dynamo rewriting synchronous in-place collectives does not work
+        # at the moment. Use functional all-reduce to avoid graph break.
+        amax = dist._functional_collectives.all_reduce(
+            amax, "MAX", list(range(dist.get_world_size()))
+        )
+        # dist.all_reduce(amax, op=dist.ReduceOp.MAX)
 
     return amax
