
Removed module arg from fsdp_pre_all_gather #217

Closed · wants to merge 7 commits

Conversation

@awgu commented Feb 15, 2024

Stack from ghstack (oldest at bottom):

On the discussion of whether fp8 all-gather is viable _without_ compiling the pre-all-gather cast to fp8: this PR would add _more_ CPU overhead due to the `__torch_function__` override, making that approach less viable.

@facebook-github-bot added the CLA Signed label Feb 15, 2024
awgu pushed a commit that referenced this pull request Feb 15, 2024
ghstack-source-id: d3c0a7b
Pull Request resolved: #217
@@ -232,7 +232,7 @@ def cast_w_to_float8(
             self.fp8_scale_w,
             scale_fn_name,
             torch.float8_e4m3fn,
-            is_amax_initialized,
+            self.is_amax_initialized,
@awgu (Author) commented:

We have to push this `is_amax_initialized` read into the cast function since the `Float8LinearWeightTensor` subclass cannot keep a reference to the bool: a plain `bool` attribute would only be a snapshot of the flag's value at construction time, so the subclass would not see later updates.
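A tiny standalone sketch (illustrative only, not code from this PR) of why a flag copied onto the subclass would go stale:

```python
# Illustrative only: a bool stored at construction is a value copy, so later
# updates to the owning module's flag are invisible to the holder.
class WeightSubclassSketch:  # hypothetical stand-in for the tensor subclass
    def __init__(self, is_amax_initialized: bool):
        self.is_amax_initialized = is_amax_initialized  # snapshot, not a reference

flag = False
w = WeightSubclassSketch(flag)
flag = True                     # e.g. the module flips the flag after the first iteration
print(w.is_amax_initialized)    # False -- stale; reading the live flag off the module avoids this
```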

return o

with torch._C.DisableTorchFunctionSubclass():
    if isinstance(args[0], cls):
@awgu (Author) commented:

The heuristic here is simply to propagate the `Float8LinearWeightTensor` subclass as long as it is the first argument.
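For context, a minimal sketch of what such a first-argument propagation heuristic can look like in a `__torch_function__` override (illustrative only, using a simplified subclass rather than the exact code in this PR):

```python
import torch

class Float8LinearWeightTensorSketch(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # Run the op without subclass dispatch, then re-wrap the output as the
        # subclass whenever the subclass is the first positional argument.
        with torch._C.DisableTorchFunctionSubclass():
            out = func(*args, **kwargs)
            if args and isinstance(args[0], cls) and isinstance(out, torch.Tensor):
                out = out.as_subclass(cls)
            return out
```

This Python hook runs on every op that touches the weight, which is where the extra CPU overhead mentioned in the PR description comes from.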

awgu pushed a commit that referenced this pull request Feb 15, 2024
ghstack-source-id: aa0fe4a
Pull Request resolved: #217
@@ -142,12 +147,34 @@ def from_float(


class Float8DynamicLinearWeightTensor(torch.Tensor):
    # TODO: Remove `module` arg, save state on subclass, and propagate it.
    def fsdp_pre_all_gather(
        self, module: nn.Module
@awgu (Author) commented:

The problem statement is that, in order to implement the pre-all-gather transform, the subclass needs some state (mainly `emulate`, but preferably the cast function as well). In the first prototype, I took a shortcut by passing `module` into `fsdp_pre_all_gather()` so that the transform could read state off the module instead of storing it on the subclass itself.

However, the morally right thing (from a design perspective) is to put that state on the subclass. I do wonder, though, how this kind of `__torch_function__` subclass interacts with torch.compile today. cc: @bdhirsh
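For reference, a minimal sketch of the direction described above: store the needed state on the subclass so the `module` argument can be dropped. This is illustrative only; the field names (`emulate`, `cast_fn`), the hypothetical `cast_fn` helper, and the `(all_gather_inputs, metadata)` return convention are assumptions, not the actual float8_experimental code:

```python
import torch
from typing import Any, Callable, Tuple

class Float8DynamicLinearWeightTensorSketch(torch.Tensor):
    @staticmethod
    def __new__(cls, data: torch.Tensor, emulate: bool, cast_fn: Callable):
        return data.as_subclass(cls)

    def __init__(self, data: torch.Tensor, emulate: bool, cast_fn: Callable):
        # State that the prototype read off the module now lives on the subclass.
        self.emulate = emulate
        self.cast_fn = cast_fn

    def fsdp_pre_all_gather(self) -> Tuple[Tuple[torch.Tensor, ...], Any]:
        # No `module` argument: everything the pre-all-gather cast needs is on `self`.
        fp8_data, scale = self.cast_fn(self, torch.float8_e4m3fn, emulate=self.emulate)
        return (fp8_data,), (scale,)
```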

awgu pushed a commit that referenced this pull request Feb 15, 2024
ghstack-source-id: fc5a6df
Pull Request resolved: #217
awgu pushed a commit that referenced this pull request Feb 16, 2024
ghstack-source-id: 7b34f4e
Pull Request resolved: #217
awgu pushed a commit that referenced this pull request Feb 16, 2024
ghstack-source-id: bb1e4c0
Pull Request resolved: #217
awgu pushed a commit that referenced this pull request Feb 16, 2024
ghstack-source-id: 9bfc02a
Pull Request resolved: #217
@awgu closed this Feb 27, 2024