This repository was archived by the owner on Aug 7, 2024. It is now read-only.

[POC] Added fp8 all-gather extensions #216

Closed
wants to merge 6 commits

Conversation

@awgu commented Feb 14, 2024

Stack from ghstack (oldest at bottom):

### Overview

This PR shows the prototype for enabling fp8 all-gather for `Float8Linear.weight` and `Float8DynamicLinear.weight`. This requires changes from pytorch/pytorch#119378.

The approach is to change the `weight` tensor into a tensor subclass that defines two methods, `fsdp_pre_all_gather()` and `fsdp_post_all_gather()`. We currently prefer this approach since subclasses are the blessed approach for extending at the tensor level. However, we are evaluating the implications on both eager performance and compile.

See #201 for some more notes on per-parameter FSDP and fp8.
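
For concreteness, here is a minimal sketch of what such a subclass could look like. The class name, the scaling math, and the exact `fsdp_pre_all_gather()` / `fsdp_post_all_gather()` signatures and return values are illustrative assumptions for this note, not the code in this PR or the final extension contract from pytorch/pytorch#119378.

```python
import torch


class Float8AllGatherWeight(torch.Tensor):
    """Sketch of a weight subclass that all-gathers in fp8 instead of bf16/fp32."""

    @staticmethod
    def __new__(cls, data: torch.Tensor, scale: torch.Tensor):
        return torch.Tensor._make_subclass(cls, data, require_grad=data.requires_grad)

    def __init__(self, data: torch.Tensor, scale: torch.Tensor):
        # The fp8 scale is assumed to be maintained elsewhere (delayed or dynamic scaling).
        self._scale = scale

    def fsdp_pre_all_gather(self, mesh):
        # Cast the local shard to fp8 before the collective so the all-gather moves
        # 1 byte per element instead of 2 (bf16) or 4 (fp32).
        fp8_shard = (self.to(torch.float32) * self._scale).to(torch.float8_e4m3fn)
        # Return the tensor(s) to all-gather plus metadata needed after the gather.
        return (fp8_shard,), (self._scale,)

    def fsdp_post_all_gather(self, all_gather_outputs, metadata, param_dtype, *, out=None):
        (fp8_weight,) = all_gather_outputs
        (scale,) = metadata
        # The real prototype would wrap the gathered fp8 data and its scale into a
        # Float8Tensor so float8_mm can consume it; this sketch just hands back the
        # fp8 tensor and the buffers that FSDP may later free.
        return fp8_weight, all_gather_outputs
```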

### `torch.compile` w/o fp8 all-gather

**TL;DR** only `transformer_block.forward = torch.compile(transformer_block.forward)` works today.

|                                     | Delayed Scaling | Dynamic Scaling |
|-------------------------------------|-----------------|-----------------|
| Compile Transformer Block `forward` | 🙁 requires disabling amax init <br> 🙁 requires disabling pre/post-forward <br> ✅ one graph per transformer block | ✅ one graph per transformer block <br> ❌ error in float8_mm if compiling output projection |
| Compile Transformer                 | 🙁 requires disabling amax init <br> ❌ unexpected graph breaks | ❌ error in float8_mm compile |
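
For reference, the block-level compile in the table above is just the following pattern; `ToyBlock` is a hypothetical stand-in for the repo's transformer block, not a class from this codebase.

```python
import torch
import torch.nn as nn


class ToyBlock(nn.Module):
    """Hypothetical stand-in for a transformer block."""

    def __init__(self, dim: int = 16):
        super().__init__()
        self.linear = nn.Linear(dim, dim)

    def forward(self, x):
        return torch.relu(self.linear(x))


blocks = nn.ModuleList([ToyBlock() for _ in range(2)])
for transformer_block in blocks:
    # Compile each block's forward() in place; per the table above, compiling the
    # whole transformer instead currently hits graph breaks / float8_mm errors.
    transformer_block.forward = torch.compile(transformer_block.forward)

out = blocks[0](torch.randn(4, 16))
```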

### `torch.compile` w/ fp8 all-gather

Context: Per-parameter FSDP runs all-gather in a pre-forward hook and frees parameters in a post-forward hook. Using `transformer_block.forward = torch.compile(transformer_block.forward)` does not compile the hooks, so we distinguish between two cases when doing this block-level compile: including and not including the hooks.

One way to emulate including the hooks is to change per-parameter FSDP to override `forward()` instead of using hooks, in which case `transformer_block.forward` would include FSDP's pre/post-forward logic directly. We have not investigated this yet.
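
To make the "override `forward()`" idea concrete, one hypothetical shape for it is sketched below; `_unshard()` and `_reshard()` are placeholders for FSDP's real pre-forward all-gather and post-forward free, which are not reproduced here.

```python
import torch.nn as nn


class FSDPForwardOverride(nn.Module):
    """Hypothetical wrapper that puts FSDP-style pre/post logic inside forward()
    so that compiling forward() also captures the unshard/reshard steps."""

    def __init__(self, block: nn.Module):
        super().__init__()
        self.block = block

    def forward(self, *args, **kwargs):
        self._unshard()   # placeholder for the pre-forward all-gather
        out = self.block(*args, **kwargs)
        self._reshard()   # placeholder for the post-forward parameter free
        return out

    def _unshard(self):
        pass

    def _reshard(self):
        pass
```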

|                           | Delayed Scaling | Dynamic Scaling |
|---------------------------|-----------------|-----------------|
| Compile Transformer Block `forward` w/o Forward Hooks | 🙁 requires disabling amax init <br> 🙁 requires disabling pre/post-forward <br> ✅ one graph per transformer block | ✅ one graph per transformer block <br> ❌ error in float8_mm if compiling output projection |
| Compile Transformer       | 🙁 requires disabling amax init <br> ❌ error in pre-all-gather compile | ❌ error in float8_mm compile |
| Compile Transformer Block incl. Forward Hooks | | |

awgu pushed a commit that referenced this pull request Feb 14, 2024
ghstack-source-id: ff0433d
Pull Request resolved: #216
@facebook-github-bot added the CLA Signed label Feb 14, 2024
Andrew Gu added 3 commits February 16, 2024 07:36
@awgu closed this Feb 27, 2024