This repository was archived by the owner on Aug 7, 2024. It is now read-only.

[5/x] make FSDP2 with float8 all-gather work for Float8Linear #296

Closed
wants to merge 2 commits into from

Conversation

vkuzo
Contributor

@vkuzo vkuzo commented Jul 2, 2024

Stack from ghstack (oldest at bottom):

Summary:

Adds test coverage for Float8Linear with all dynamic scaling and FSDP2
with float8 all-gather.

To make the tests pass, this fixes a bug with initialization ordering in
`Float8Linear.from_float`: the correct forward config needs to be set
before stashing it on the weight wrapper.
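
The ordering constraint described above can be sketched as follows. All names here (`ForwardConfig`, `WeightWrapper`, `from_float`) are illustrative stand-ins, not the actual float8_experimental API:

```python
# Hypothetical sketch of the initialization-ordering fix: the forward
# config must be final *before* it is stashed on the weight wrapper,
# because the wrapper captures it at construction time.

class ForwardConfig:
    def __init__(self, use_fp8_all_gather=False):
        self.use_fp8_all_gather = use_fp8_all_gather

class WeightWrapper:
    # Captures the forward config at construction time; any config
    # changes made after this point are not seen by the wrapper.
    def __init__(self, weight, forward_config):
        self.weight = weight
        self.forward_config = forward_config

def from_float(weight, enable_fsdp_fp8_all_gather):
    # Bug pattern: wrapping the weight first would stash a stale default
    # config. Fix: build the correct forward config first, then wrap.
    forward_config = ForwardConfig(use_fp8_all_gather=enable_fsdp_fp8_all_gather)
    return WeightWrapper(weight, forward_config)
```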

Test Plan:

```
python test/test_fsdp2/test_fsdp2_eager.py
./test/test_everything.sh
```

Reviewers:

Subscribers:

Tasks:

Tags:

Differential Revision: D59305793

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jul 2, 2024
vkuzo added a commit that referenced this pull request Jul 2, 2024
ghstack-source-id: b6d6525
Pull Request resolved: #296
vkuzo added a commit that referenced this pull request Jul 2, 2024
ghstack-source-id: 26b7138
Pull Request resolved: #296
@vkuzo vkuzo requested review from drisspg, awgu and weifengpy July 2, 2024 20:34
)
new_mod.weight = mod.weight
else:
assert not config.enable_fsdp_fp8_all_gather, "unsupported"
Contributor

Nit: maybe a more helpful assert message
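
One way to address this nit, as a sketch: name the flag and the unsupported case in the message rather than a bare `"unsupported"`. The `Config` class and helper below are hypothetical stand-ins for the snippet above:

```python
# Hypothetical illustration of a more helpful assert message. The message
# names the flag and the module type, which reads better in a traceback.

class Config:
    # Stand-in for the real config object's enable_fsdp_fp8_all_gather flag.
    enable_fsdp_fp8_all_gather = True

def check_supported(config, module_cls_name):
    assert not config.enable_fsdp_fp8_all_gather, (
        f"enable_fsdp_fp8_all_gather is not supported for {module_cls_name}"
    )
```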

@@ -74,8 +76,16 @@ def get_local_inp(self, dtype: torch.dtype = torch.float32):
dist.broadcast(global_inp, src=0)
return global_inp.view(self.world_size, -1)[self.rank].view(16, 16)

def swap_linear_with_dynamic(self, module: nn.Module, **kwargs: Any) -> nn.Module:
return swap_linear_with_float8_linear(module, Float8DynamicLinear, **kwargs)
def swap_linear_with_dynamic(
Contributor

Maybe I'm losing some context, but is there a reason why the existing swap function doesn't work?

Contributor Author

If the question is why we need `swap_linear_with_dynamic` at all: we probably don't. Removing it is unrelated to this PR, though, so I left it for a future person.

self._test_transformer_memory(enable_fsdp_fp8_all_gather)

def _test_transformer_memory(self, enable_fsdp_fp8_all_gather: bool):
# for enable_fsdp_fp8_all_gather in [False, True]:
Contributor

can remove this comment, right?

Contributor

@drisspg drisspg left a comment

Looks good. Maybe add a dummy test that `Float8Linear` with not-all-dynamic scaling errors out when trying to use fp8 all-gather.
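
The suggested negative test could look roughly like the sketch below. Everything here is hypothetical (a stand-in `from_float` with a string `scaling_type` parameter), not the real test code from this PR:

```python
# Hypothetical sketch of the suggested negative test: enabling fp8
# all-gather with non-dynamic (e.g. delayed) scaling should raise.

def from_float(scaling_type, enable_fsdp_fp8_all_gather):
    # Stand-in for the real conversion: fp8 all-gather is assumed to be
    # supported only when all scaling is dynamic.
    if enable_fsdp_fp8_all_gather and scaling_type != "dynamic":
        raise AssertionError("fp8 all-gather requires all-dynamic scaling")
    return object()

def test_fp8_all_gather_requires_dynamic_scaling():
    raised = False
    try:
        from_float("delayed", enable_fsdp_fp8_all_gather=True)
    except AssertionError:
        raised = True
    assert raised, "expected an error for delayed scaling + fp8 all-gather"

test_fp8_all_gather_requires_dynamic_scaling()
```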

@vkuzo
Contributor Author

vkuzo commented Jul 2, 2024

@vkuzo has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@facebook-github-bot
Contributor

This pull request has been merged in 412222b.
