Stack from ghstack (oldest at bottom):

* `module` arg from `fsdp_pre_all_gather` #217
* `use_activation_hooks: bool` to swap #214
* `amax_and_scale_synced` unconditionally #220
### Overview

This PR shows a prototype for enabling fp8 all-gather for `Float8Linear.weight` and `Float8DynamicLinear.weight`. It requires the changes from pytorch/pytorch#119378.

The approach is to change the `weight` tensor into a tensor subclass that defines two methods, `fsdp_pre_all_gather()` and `fsdp_post_all_gather()`. We currently prefer this approach since subclasses are the blessed way to extend at the tensor level. However, we are still evaluating the implications for both eager performance and compile.

See #201 for more notes on per-parameter FSDP and fp8.
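To make the extension point concrete, here is a minimal sketch of what such a weight subclass could look like. The class name, method signatures, and return conventions below are illustrative assumptions for this note, not the exact interface from pytorch/pytorch#119378:

```python
import torch


class Float8AllGatherWeight(torch.Tensor):
    """Sketch of a weight subclass that participates in fp8 all-gather.

    NOTE: the method names match the extension points described above, but the
    signatures and return conventions here are simplified assumptions, not the
    exact interface from pytorch/pytorch#119378.
    """

    @staticmethod
    def __new__(cls, data: torch.Tensor, scale: torch.Tensor):
        # Wrapper subclass: the outer tensor advertises the original shape/dtype,
        # while the inner tensors hold the actual storage.
        return torch.Tensor._make_wrapper_subclass(
            cls,
            data.shape,
            dtype=data.dtype,
            device=data.device,
            requires_grad=data.requires_grad,
        )

    def __init__(self, data: torch.Tensor, scale: torch.Tensor):
        self._data = data    # high-precision local shard
        self._scale = scale  # fp8 scale for this weight

    def fsdp_pre_all_gather(self):
        # Cast the local shard to fp8 before communication so the all-gather
        # moves 1-byte elements instead of 2-byte (bf16) elements.
        fp8_shard = (self._data * self._scale).to(torch.float8_e4m3fn)
        return (fp8_shard,), (self._scale,)  # (all-gather inputs, metadata)

    def fsdp_post_all_gather(self, all_gather_outputs, metadata, *, out=None):
        # Reassemble the unsharded fp8 weight plus its scale after the
        # all-gather; the fp8 matmul can then consume them directly.
        (fp8_weight,) = all_gather_outputs
        (scale,) = metadata
        return fp8_weight, scale
```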
### `torch.compile` w/o fp8 all-gather

**TL;DR:** only `transformer_block.forward = torch.compile(transformer_block.forward)` works today.

| | Delayed Scaling | Dynamic Scaling |
|---|---|---|
| Compile Transformer Block `forward` | 🙁 requires disabling amax init<br>🙁 requires disabling pre/post-forward<br>✅ one graph per transformer block | ✅ one graph per transformer block<br>❌ error in `float8_mm` if compiling the output projection |
| Compile Transformer | 🙁 requires disabling amax init<br>❌ unexpected graph breaks | ❌ error in `float8_mm` compile |
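For reference, "compile transformer block `forward`" in the table means rebinding each block's `forward` rather than compiling the whole model. A minimal sketch, assuming the blocks live in a `model.layers` container (an assumed attribute name):

```python
import torch
import torch.nn as nn


def compile_transformer_blocks(model: nn.Module) -> None:
    """Block-level compile: rebind each transformer block's forward to a
    compiled version, leaving the surrounding model (and any FSDP hooks)
    in eager mode. `model.layers` is an assumed attribute name for the
    container of transformer blocks."""
    for transformer_block in model.layers:
        transformer_block.forward = torch.compile(transformer_block.forward)


# The whole-model alternative, which currently hits the failure modes above:
#     model = torch.compile(model)
```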
### `torch.compile` w/ fp8 all-gather

Context: per-parameter FSDP runs all-gather in a pre-forward hook and frees parameters in a post-forward hook. Using `transformer_block.forward = torch.compile(transformer_block.forward)` does not compile the hooks, so we distinguish two cases for this block-level compile: including and not including the hooks.

One way to emulate including the hooks is to change per-parameter FSDP to override `forward()` instead of using hooks, in which case `transformer_block.forward` would include FSDP's pre/post-forward logic directly (see the sketch after the table below). We have not investigated this yet.

| | Delayed Scaling | Dynamic Scaling |
|---|---|---|
| Compile Transformer Block `forward` w/o Forward Hooks | 🙁 requires disabling amax init<br>🙁 requires disabling pre/post-forward<br>✅ one graph per transformer block | ✅ one graph per transformer block<br>❌ error in `float8_mm` if compiling the output projection |
| Compile Transformer | 🙁 requires disabling amax init<br>❌ error in pre-all-gather compile | ❌ error in `float8_mm` compile |
| Compile Transformer Block incl. Forward Hooks | | |
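For completeness, a rough sketch of the "override `forward()` instead of hooks" idea mentioned above. This is purely hypothetical: the `fsdp_state` handle and its `pre_forward()`/`post_forward()` methods are placeholders for whatever logic the real hooks run, not an existing FSDP API.

```python
import torch.nn as nn


class FSDPForwardOverride(nn.Module):
    """Hypothetical wrapper that folds FSDP's pre/post-forward logic into
    forward() so that compiling the block's forward also captures it.
    `fsdp_state` and its pre_forward()/post_forward() methods are assumed
    stand-ins for the logic the real hooks run; per-parameter FSDP does not
    expose this interface today."""

    def __init__(self, block: nn.Module, fsdp_state):
        super().__init__()
        self.block = block
        self.fsdp_state = fsdp_state

    def forward(self, *args, **kwargs):
        self.fsdp_state.pre_forward()    # all-gather the (possibly fp8) parameters
        out = self.block(*args, **kwargs)
        self.fsdp_state.post_forward()   # free the unsharded parameters
        return out
```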