[5/x] make FSDP2 with float8 all-gather work for Float8Linear #296
Summary: Adds test coverage for `Float8Linear` with all dynamic scaling and FSDP2 with float8 all-gather. To make the tests pass, this fixes an initialization-ordering bug in `Float8Linear.from_float`: the correct forward config must be set before it is stashed on the weight wrapper.

Test Plan:
```
python test/test_fsdp2/test_fsdp2_eager.py
./test/test_everything.sh
```
Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned]
ghstack-source-id: b6d6525 Pull Request resolved: #296
ghstack-source-id: 26b7138 Pull Request resolved: #296
)
new_mod.weight = mod.weight
else:
assert not config.enable_fsdp_fp8_all_gather, "unsupported"
Nit: maybe a more helpful assert message
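One way to act on this nit is sketched below. It is only an illustration: `ScalingType`, `Config`, and `check_all_gather_supported` are hypothetical stand-ins written for this example, not the library's actual classes or functions, and the message wording is a suggestion. The idea is that the assert should name the offending option, report the value it saw, and say how to fix it.

```python
from dataclasses import dataclass
from enum import Enum


class ScalingType(Enum):
    """Hypothetical stand-in for the library's scaling-type enum."""

    DELAYED = "delayed"
    DYNAMIC = "dynamic"


@dataclass
class Config:
    """Hypothetical stand-in for the config object checked in the diff."""

    enable_fsdp_fp8_all_gather: bool = False


def check_all_gather_supported(config: Config, weight_scaling: ScalingType) -> None:
    """Fail with an actionable message when float8 all-gather is requested
    for a weight that does not use dynamic scaling."""
    if weight_scaling is ScalingType.DYNAMIC:
        return  # the supported combination, nothing to check
    assert not config.enable_fsdp_fp8_all_gather, (
        "enable_fsdp_fp8_all_gather=True requires dynamic scaling for the weight, "
        f"but got {weight_scaling.name}; switch the weight to dynamic scaling "
        "or disable float8 all-gather"
    )
```

Compared with a bare "unsupported", a message like this tells the user which setting to change without reading the source.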
@@ -74,8 +76,16 @@ def get_local_inp(self, dtype: torch.dtype = torch.float32):
    dist.broadcast(global_inp, src=0)
    return global_inp.view(self.world_size, -1)[self.rank].view(16, 16)

def swap_linear_with_dynamic(self, module: nn.Module, **kwargs: Any) -> nn.Module:
    return swap_linear_with_float8_linear(module, Float8DynamicLinear, **kwargs)
def swap_linear_with_dynamic(
Maybe I'm losing some context, but is there a reason why the existing swap function doesn't work?
If the question is why we need `swap_linear_with_dynamic`, we probably don't. Removing it is not related to this PR, though, so I left it for a future person.
    self._test_transformer_memory(enable_fsdp_fp8_all_gather)

def _test_transformer_memory(self, enable_fsdp_fp8_all_gather: bool):
    # for enable_fsdp_fp8_all_gather in [False, True]:
We can remove this comment, right?
Looks good. Maybe add a dummy test that `Float8Linear` without all-dynamic scaling errors when trying to use fp8 all-gather.
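A rough shape for such a test is sketched below. It does not use the real library: `Config`, its `dynamic_weight_scaling` field, and `from_float` are hypothetical stand-ins that only mimic the assert shown in the diff, so the test body would need to be adapted to the actual `Float8Linear.from_float` signature.

```python
import unittest
from dataclasses import dataclass


@dataclass
class Config:
    """Hypothetical stand-in for the real config object."""

    enable_fsdp_fp8_all_gather: bool = False
    dynamic_weight_scaling: bool = True  # hypothetical flag for this sketch


def from_float(config: Config) -> str:
    """Hypothetical stand-in for the conversion entry point.

    Mirrors the diff: float8 all-gather is rejected unless the weight
    uses dynamic scaling.
    """
    if not config.dynamic_weight_scaling:
        assert not config.enable_fsdp_fp8_all_gather, (
            "float8 all-gather requires dynamic weight scaling"
        )
    return "converted"


class TestAllGatherRequiresDynamicScaling(unittest.TestCase):
    def test_non_dynamic_with_all_gather_errors(self):
        config = Config(enable_fsdp_fp8_all_gather=True, dynamic_weight_scaling=False)
        with self.assertRaises(AssertionError):
            from_float(config)

    def test_dynamic_with_all_gather_is_accepted(self):
        config = Config(enable_fsdp_fp8_all_gather=True, dynamic_weight_scaling=True)
        self.assertEqual(from_float(config), "converted")
```

Saved to a file, the sketch can be run with `python -m unittest <file>`.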
@vkuzo has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
This pull request has been merged in 412222b.
Stack from ghstack (oldest at bottom):
Summary:
Adds test coverage for `Float8Linear` with all dynamic scaling and FSDP2 with float8 all-gather.
To make the tests pass, this fixes an initialization-ordering bug in `Float8Linear.from_float`: the correct forward config must be set before it is stashed on the weight wrapper.
Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:
Differential Revision: D59305793
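The initialization-ordering issue named in the summary can be illustrated with a minimal sketch. Everything here is assumed for illustration: `ForwardConfig`, its `use_fast_accum` field, `WeightWrapper`, `Module`, and the two `from_float_*` functions are hypothetical and do not reproduce the library's code. The sketch only shows why a wrapper that captures a config at construction time must be created after the final config is in place.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass(frozen=True)
class ForwardConfig:
    """Hypothetical config; the field name is made up for this sketch."""

    use_fast_accum: bool = False


class WeightWrapper:
    """Hypothetical wrapper that stashes whatever config it receives."""

    def __init__(self, data: Any, forward_config: ForwardConfig) -> None:
        self.data = data
        self.forward_config = forward_config


class Module:
    def __init__(self) -> None:
        self.forward_config = ForwardConfig()  # placeholder default
        self.weight: Optional[WeightWrapper] = None


def from_float_buggy(data: Any) -> Module:
    mod = Module()
    # The wrapper captures the placeholder config ...
    mod.weight = WeightWrapper(data, mod.forward_config)
    # ... and the real config is only set afterwards, so the wrapper is stale.
    mod.forward_config = ForwardConfig(use_fast_accum=True)
    return mod


def from_float_fixed(data: Any) -> Module:
    mod = Module()
    # Set the real config first, then stash it on the wrapper.
    mod.forward_config = ForwardConfig(use_fast_accum=True)
    mod.weight = WeightWrapper(data, mod.forward_config)
    return mod
```

In the buggy ordering the module and its weight wrapper disagree about the config; in the fixed ordering they hold the same one.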