This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 3f85814

Author: Andrew Gu
Used functional all-reduce for amax reduction
ghstack-source-id: 256aebe
Pull Request resolved: #219
1 parent: 7032367

File tree

1 file changed: +7 -1 lines changed

float8_experimental/float8_utils.py

Lines changed: 7 additions & 1 deletion
@@ -75,7 +75,13 @@ def tensor_to_amax(x, distributed_reduction=False):
     # If the user did not ask for it, assume that it will
     # happen elsewhere.
     if distributed_reduction and dist.is_initialized():
-        dist.all_reduce(amax, op=dist.ReduceOp.MAX)
+        # TODO: Dynamo rewriting synchronous in-place collectives fails:
+        # https://github.com/pytorch/pytorch/issues/120082
+        # Use functional all-reduce to avoid graph breaking.
+        amax = dist._functional_collectives.all_reduce(
+            amax, "MAX", list(range(dist.get_world_size()))
+        )
+        # dist.all_reduce(amax, op=dist.ReduceOp.MAX)

     return amax
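For context, a minimal sketch of the change in use. The single-process "gloo" setup, the x.abs().max() amax computation, and the explicit submodule import are assumptions for illustration, not part of this diff. The key point is that the functional collective returns a new tensor instead of mutating amax in place, which is what lets Dynamo trace it without a graph break:

# Minimal sketch, not from the diff: single-process setup so the
# functional all-reduce can run locally for illustration.
import os

import torch
import torch.distributed as dist
import torch.distributed._functional_collectives  # makes dist._functional_collectives resolvable


os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)


def tensor_to_amax(x, distributed_reduction=False):
    amax = x.abs().max()  # assumed amax computation; the hunk above starts below this line
    if distributed_reduction and dist.is_initialized():
        # Out-of-place collective: returns a new tensor rather than
        # mutating `amax`, so Dynamo can trace through it.
        amax = dist._functional_collectives.all_reduce(
            amax, "MAX", list(range(dist.get_world_size()))
        )
    return amax


x = torch.randn(16, 16)
print(tensor_to_amax(x, distributed_reduction=True))  # max |x| over all ranks
dist.destroy_process_group()

The function should then compile cleanly with torch.compile(tensor_to_amax), which is the point of the change; the in-place dist.all_reduce hits the graph break tracked in the linked issue.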
