This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Used functional all-reduce for amax reduction #219

Closed · wants to merge 4 commits

8 changes: 7 additions & 1 deletion float8_experimental/float8_utils.py
@@ -75,7 +75,13 @@ def tensor_to_amax(x, distributed_reduction=False):
     # If the user did not ask for it, assume that it will
     # happen elsewhere.
     if distributed_reduction and dist.is_initialized():
-        dist.all_reduce(amax, op=dist.ReduceOp.MAX)
+        # TODO: Dynamo rewriting synchronous in-place collectives fails:
+        # https://github.com/pytorch/pytorch/issues/120082
+        # Use functional all-reduce to avoid graph breaking.
+        amax = dist._functional_collectives.all_reduce(
+            amax, "MAX", list(range(dist.get_world_size()))
+        )
+        # dist.all_reduce(amax, op=dist.ReduceOp.MAX)

     return amax

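For context, here is a minimal sketch (not part of the PR) of the behavioral difference the patch relies on: the in-place c10d all_reduce mutates the tensor it is given, while the functional collective leaves its input untouched and returns the reduced value, which is why the diff reassigns amax. The gloo backend, the two-process torchrun launch, and the file name are illustrative assumptions.

```python
# Hedged sketch, not from the PR: compare in-place vs. functional all-reduce
# of an amax value. Assumed launch: torchrun --nproc_per_node=2 amax_allreduce_demo.py
import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol

dist.init_process_group("gloo")

x = torch.randn(16)
amax = x.abs().max()

# In-place collective: mutates the tensor passed to it.
inplace_amax = amax.clone()
dist.all_reduce(inplace_amax, op=dist.ReduceOp.MAX)

# Functional collective: the input is left untouched and the reduced value is
# returned, so it has to be reassigned, exactly as the diff above does.
functional_amax = funcol.all_reduce(
    amax, "MAX", list(range(dist.get_world_size()))
)

assert float(inplace_amax) == float(functional_amax)
dist.destroy_process_group()
```

Per the TODO in the diff, the functional form is also the one Dynamo can trace without rewriting a synchronous in-place collective (pytorch/pytorch#120082).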
Review thread (on the new functional all_reduce call):

Reviewer:
Ranks + tag as the process group identifier has been deprecated. Can we pass dist.group.WORLD or dist.group.WORLD.group_name here?
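A sketch of what the suggested change could look like, assuming the functional-collectives API accepts a ProcessGroup (and, in newer builds, its registered group name) as the group argument, as the comment indicates. This is an illustration of the suggestion, not a commit in this PR; the gloo backend and torchrun launch are again assumptions.

```python
# Hypothetical variant of the diff line, following the suggestion above:
# identify the group via the default ProcessGroup instead of ranks + tag.
import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol

dist.init_process_group("gloo")

amax = torch.rand(())

# Pass the default process group object directly...
amax = funcol.all_reduce(amax, "MAX", dist.group.WORLD)

# ...or, per the same suggestion, its registered name (only if the installed
# functional-collectives version accepts a group-name string):
# amax = funcol.all_reduce(amax, "MAX", dist.group.WORLD.group_name)

print(f"rank {dist.get_rank()}: reduced amax = {float(amax)}")
dist.destroy_process_group()
```

Either form avoids hard-coding list(range(world_size)) together with the deprecated tag.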

Author:
If we are not landing this PR, then is it okay to leave the call as dist.all_reduce(amax, op=dist.ReduceOp.MAX) and wait for your Dynamo rewrite changes?

Reviewer:
Oh, the changes you are referring to are for rewriting functional collectives. The options I mentioned above should already work :)

Let me know if they don't, though.

