Used functional all-reduce for amax reduction #219
Conversation
Dynamo cannot remap the in-place max all-reduce to its functional equivalent on the PyTorch `main` branch. pytorch/pytorch#120082
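For context, a minimal sketch of the two call patterns being compared, assuming `torch.distributed` is already initialized with a default process group; the `reduce_amax` wrapper is illustrative and not code from this repo:

```python
import torch
import torch.distributed as dist

def reduce_amax(amax: torch.Tensor) -> torch.Tensor:
    """Reduce a local amax across all ranks with MAX (illustrative helper)."""
    # In-place variant that graph-breaks under torch.compile on current `main`,
    # since Dynamo cannot yet remap it to a functional collective
    # (pytorch/pytorch#120082):
    #
    #   dist.all_reduce(amax, op=dist.ReduceOp.MAX)

    # Functional variant used by this PR: returns a new tensor and traces cleanly.
    return dist._functional_collectives.all_reduce(
        amax, "MAX", list(range(dist.get_world_size()))
    )
```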
Looks great!
I will hold off on landing this since @yifuwang will land the Dynamo rewrite fix soon.
# https://github.com/pytorch/pytorch/issues/120082
# Use functional all-reduce to avoid graph breaking.
amax = dist._functional_collectives.all_reduce(
    amax, "MAX", list(range(dist.get_world_size()))
)
Ranks + tag as the process group identifier has been deprecated. Can we pass `dist.group.WORLD` or `dist.group.WORLD.group_name` here?
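For concreteness, a minimal sketch of the two suggested spellings, assuming the default process group is initialized and `amax` is the tensor from the diff above:

```python
import torch.distributed as dist

# Pass the ProcessGroup object instead of ranks + tag:
amax = dist._functional_collectives.all_reduce(amax, "MAX", dist.group.WORLD)

# Or pass the group's registered name (a string):
amax = dist._functional_collectives.all_reduce(
    amax, "MAX", dist.group.WORLD.group_name
)
```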
If we are not landing this PR, then is it okay to leave the call as `dist.all_reduce(amax, op=dist.ReduceOp.MAX)` and wait for your Dynamo rewrite changes?
Oh, the changes you are referring to are for rewriting functional collectives. The options I mentioned above should already work :)
Let me know if they don't, though.
Stack from ghstack (oldest at bottom):
- module arg from fsdp_pre_all_gather #217
- use_activation_hooks: bool to swap #214
- amax_and_scale_synced unconditionally #220

Dynamo cannot remap the in-place max all-reduce to its functional equivalent on the PyTorch `main` branch. pytorch/pytorch#120082