This repository was archived by the owner on Aug 7, 2024. It is now read-only.

support delayed scaling of weight in float8 all-gather #312

Closed
wants to merge 3 commits

Conversation

@vkuzo (Contributor) commented Jul 9, 2024

Stack from ghstack (oldest at bottom):

Summary:

Adds support for delayed scaling in FSDP2 float8 all-gather. In detail:

1. add `WeightWithDelayedFloat8CastTensor` (a rough sketch of the idea follows this list); note that we don't reuse code with the dynamic version because I'd rather not deal with plumbing optional tensors through dynamo. We can try that in a separate PR later.
2. wire `Float8Linear` to use (1)
3. add weight amax syncing back, since we need it for float8 all-gather
4. add test coverage for eager mode numerics
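
For intuition, here is a rough, hypothetical sketch of what the delayed-scaling cast of the weight shard looks like before the FSDP all-gather; the helper name and buffer layout below are illustrative, not the exact code added in this PR:

```python
import torch

# Illustrative only: the real WeightWithDelayedFloat8CastTensor is a tensor
# subclass and performs this work inside fsdp_pre_all_gather.
E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max

def delayed_cast_weight_for_all_gather(
    weight: torch.Tensor,        # local high-precision weight shard
    amax_buffer: torch.Tensor,   # scalar buffer, folded into the history by the amax sync
    amax_history: torch.Tensor,  # rolling window of amax values from previous iterations
    scale_buffer: torch.Tensor,  # scalar scale derived from the history
):
    # record the current amax so the periodic sync can add it to the history
    amax_buffer.copy_(weight.abs().max().to(torch.float32))
    # "delayed": the scale comes from the history, not from the tensor being cast right now
    scale_buffer.copy_(E4M3_MAX / torch.clamp(amax_history.max(), min=1e-12))
    # cast to float8 so the all-gather moves one byte per element
    data_fp8 = (weight.float() * scale_buffer).clamp(-E4M3_MAX, E4M3_MAX).to(torch.float8_e4m3fn)
    return data_fp8, scale_buffer
```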

Next up (in separate PRs) will be training run validation for numerics, and
taking a look at performance.

Test Plan:

```
./test/test_everything.sh
```


Differential Revision: D59685258

vkuzo added a commit that referenced this pull request Jul 9, 2024
ghstack-source-id: f1707c1
Pull Request resolved: #312
@facebook-github-bot added the CLA Signed label Jul 9, 2024
@vkuzo vkuzo requested review from awgu, drisspg, weifengpy and bdhirsh July 9, 2024 20:37
@weifengpy (Contributor)

> I'd rather not deal with plumbing optional tensors through dynamo

what are the optional tensors?

```diff
 all_amax_tensors = torch.cat(
-    fp8_amax_x_tensor_list + fp8_amax_dL_dY_tensor_list
+    fp8_amax_x_tensor_list
+    + fp8_amax_w_tensor_list
```
Contributor

should we only do this if we are using fp8 all-gather?

Contributor Author

That could make sense; I'd love to see data on whether this matters for performance. I'm focusing on numerics for now and hoping performance can be tackled in future PRs.
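
As context for the hunk above, the amax sync concatenates all per-tensor amax buffers, runs a single collective, and writes the reduced values back; the diff adds the weight amax buffers to that concatenation. A minimal sketch of the pattern (illustrative, not the exact code in this repo):

```python
import torch
import torch.distributed as dist

def sync_amax_buffers(amax_buffers: list[torch.Tensor]) -> None:
    # flatten every scalar amax buffer (x, w, dL_dY, ...) into one tensor
    all_amax = torch.cat([buf.reshape(1) for buf in amax_buffers])
    # one all-reduce over the flat tensor instead of one collective per buffer
    dist.all_reduce(all_amax, op=dist.ReduceOp.MAX)
    # scatter the reduced values back into the individual buffers
    for buf, reduced in zip(amax_buffers, all_amax.unbind()):
        buf.copy_(reduced)
```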

```diff
@@ -110,3 +112,181 @@ def fsdp_post_all_gather(
             out._scale = scale
             return
         return Float8Tensor(data, scale, param_dtype, self._mm_config), (data,)


+class WeightWithDelayedFloat8CastTensor(torch.Tensor):
```
@drisspg (Contributor) Jul 10, 2024

[no change needed] I wish there was a way to share some more code with the dynamic version

Contributor Author

Yeah, me too. Looking at the code below, really the only code that could be shared is `fsdp_post_all_gather`; everything else would need if/else branches for delayed vs. dynamic.

```python
    def __repr__(self):
        return f"WeightWithDelayedFloat8CastTensor(tensor={self._tensor}, amax_buffer={self._amax_buffer}, scale_buffer={self._scale_buffer}, mm_config={self._mm_config})"

    def fsdp_pre_all_gather(self, mesh):
```
Contributor

I'll let @weifengpy confirm this portion.

Contributor

Confirming that the FSDP part looks good.

vkuzo added a commit that referenced this pull request Jul 10, 2024
ghstack-source-id: c83e4df
Pull Request resolved: #312
@vkuzo vkuzo requested a review from drisspg July 10, 2024 23:20

```python
    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs=None):
        if func == torch.ops.aten.detach.default:
```
Contributor

mostly just a nit, but any reason to special-case detach here? Alternatively, you could set it up so that every view op automatically propagates subclass-ness in the same way

If this is something I wrote, I think it was just something I saw in some other subclasses. Having every view op propagate subclass-ness in the same way sounds good to me.
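
To illustrate the suggestion above, here is a toy wrapper subclass (not the code in this PR) where an allow-list of view ops, rather than just detach, re-wraps its outputs so the subclass propagates; the class name and op list are hypothetical:

```python
import torch
from torch.utils._pytree import tree_map_only

aten = torch.ops.aten
# hypothetical allow-list; a real implementation would enumerate all the view ops it cares about
VIEW_OPS = {aten.detach.default, aten.view.default, aten.t.default, aten.slice.Tensor}

class ViewPropagatingWeight(torch.Tensor):
    """Toy wrapper subclass showing uniform view-op propagation."""

    @staticmethod
    def __new__(cls, inner: torch.Tensor):
        return torch.Tensor._make_wrapper_subclass(
            cls, inner.shape, dtype=inner.dtype, device=inner.device,
            requires_grad=inner.requires_grad,
        )

    def __init__(self, inner: torch.Tensor):
        self._tensor = inner

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs=None):
        kwargs = kwargs or {}
        unwrap = lambda t: t._tensor  # strip the wrapper before calling into aten
        out = func(*tree_map_only(cls, unwrap, args), **tree_map_only(cls, unwrap, kwargs))
        if func in VIEW_OPS:
            # re-wrap tensor outputs so views keep the subclass (and, in the real
            # class, its amax/scale buffers)
            return tree_map_only(torch.Tensor, lambda t: cls(t), out)
        return out

w = ViewPropagatingWeight(torch.randn(4, 4))
assert isinstance(w.detach(), ViewPropagatingWeight)
assert isinstance(w.t(), ViewPropagatingWeight)
```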

@weifengpy weifengpy left a comment (Contributor)

stamping for the fsdp part

documenting 2 open questions (not blockers for this PR):

  • should we merge WeightWithDelayedFloat8CastTensor and WeightWithDynamicFloat8CastTensor into one class and add if-else to unify the logic around __torch_dispatch__ and fsdp_pre_all_gather/fsdp_post_all_gather? We unified Float8Linear already.
  • compare perf between sync_float8_amax_and_scale_history and precompute_float8_dynamic_scale_for_fsdp. If they are similar, people would not need to worry about numeric problems from delayed scaling.

@vkuzo (Contributor Author) commented Jul 12, 2024

> should we merge WeightWithDelayedFloat8CastTensor and WeightWithDynamicFloat8CastTensor into one class and add if-else to unify the logic around __torch_dispatch__ and fsdp_pre_all_gather/fsdp_post_all_gather? We unified Float8Linear already.

I'm open to it if someone is interested in doing that in a follow-up PR. I'm not sure it will be better than what we have now, though. Note that Float8Linear was unified to allow for finer-grained configuration of scaling (per-tensor instead of per-module); that benefit is not on the table here.

> compare perf between sync_float8_amax_and_scale_history and precompute_float8_dynamic_scale_for_fsdp. If they are similar, people would not need to worry about numeric problems from delayed scaling.

Yes, that would be great! I think we can do this in follow-up PRs. Note that delayed scaling is theoretically faster than dynamic scaling (fewer memory reads), but performance is not optimized across the stack yet. I think it's good to have options and allow people to optimize different settings in parallel. Eventually, if there is clear data that only one of these is needed, we can delete the others.
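
For reference, a rough sketch of the difference between the two recipes being compared (illustrative formulas, not the library's implementation): dynamic scaling reads the tensor at cast time to compute its amax, while delayed scaling derives the scale from amax values recorded on earlier iterations.

```python
import torch

E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # max representable e4m3 value

def dynamic_scale(t: torch.Tensor) -> torch.Tensor:
    # dynamic: read the whole tensor every iteration to compute amax, then the scale
    amax = t.abs().max().to(torch.float32)
    return E4M3_MAX / torch.clamp(amax, min=1e-12)

def delayed_scale(amax_history: torch.Tensor) -> torch.Tensor:
    # delayed: reuse previously recorded amax values ("fewer memory reads"),
    # here reduced with max over the history window
    return E4M3_MAX / torch.clamp(amax_history.max(), min=1e-12)
```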

vkuzo added a commit that referenced this pull request Jul 12, 2024
ghstack-source-id: cdc9d96
Pull Request resolved: #312
@vkuzo (Contributor Author) commented Jul 12, 2024

@vkuzo has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@facebook-github-bot (Contributor)

This pull request has been merged in de93990.


```python
    def fsdp_pre_all_gather(self, mesh):
        # initialize if needed
        # TODO(before land): ensure settings are consistent between Float8Linear and here
```
do we still need to resolve this?

```python
            self._amax_buffer,
            self._amax_history_buffer,
            self._scale_buffer,
            "max",  # TODO(before land): read this from parent
```
ditto
