This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 6cba2ae

weifengpy authored and facebook-github-bot committed
precompute scale after optimizer.step for dynamic scaling (#266)
Summary:

Goal: improve float8 all-gather perf in FSDP2 by precomputing scales for all float8 params with a single all-reduce.

Updated README for API usage: call `precompute_float8_dynamic_scale_for_fsdp` inside the training loop after the optimizer step.

```
from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp

# inside the training loop
model(input).sum().backward()
optim.step()
precompute_float8_dynamic_scale_for_fsdp(model)
```

Unit test: `pytest -s test/test_fsdp2/test_fsdp2_eager.py -k test_transformer_parity_dynamic`

**FSDP pre-forward**: shortened from 3ms to 1.8ms by doing 1 all-reduce instead of N small all-reduces.
(profiler screenshots, 2024-05-30: FSDP pre-forward before vs. after)

**Pre-computing amax**: shortened from 5ms to 1.7ms by switching from `torch._foreach_abs` + `torch.max(a)` to `torch._foreach_norm(weights, ord=math.inf)`.
(profiler screenshots, 2024-05-30: amax precompute before vs. after)

Pull Request resolved: #266
Reviewed By: vkuzo
Differential Revision: D59562409
Pulled By: weifengpy
fbshipit-source-id: 683c4719e20f6b30f39ca9109ee29e53981a2aec
1 parent 73fd168 commit 6cba2ae
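
The amax speedup above relies on the fact that a tensor's inf-norm equals `max(abs(w))`, so one fused `torch._foreach_norm` call can replace a per-tensor `abs`/`max` pair over the whole weight list. A minimal sketch of that equivalence (the tensor list here is an illustrative stand-in, not part of the commit):

```
import math

import torch

# illustrative weights; any list of float tensors works
weights = [torch.randn(16, 32), torch.randn(64, 8)]

# per-tensor approach the commit moves away from: one abs + one max per tensor
amax_old = [w.abs().max() for w in weights]

# fused approach the commit adopts: a single foreach call over the whole list
amax_new = torch._foreach_norm(weights, ord=math.inf)

# inf-norm of a tensor == max(abs(tensor)), so both paths agree
for a, b in zip(amax_old, amax_new):
    assert torch.allclose(a, b)
```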

File tree: 5 files changed, +122 −12 lines

README.md

Lines changed: 13 additions & 1 deletion

@@ -37,6 +37,7 @@ This is the most accurate recipe as every tensor is scaled dynamically.
 from float8_experimental.float8_linear_utils import (
     swap_linear_with_float8_linear,
 )
+from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp
 from float8_experimental.float8_linear import Float8Linear
 
 # create model
@@ -51,7 +52,18 @@ model = FSDP(model, use_orig_params=True)
 # optional: enable torch.compile for improved performance
 m = torch.compile(m)
 
-# train/finetune (not shown)
+# toy training loop
+for _ in range(N_ITER):
+    optimizer.zero_grad()
+    y = m(x)
+    y.sum().backward()
+    optimizer.step()
+
+    # specific to fsdp2 + dynamic scaling, when fp8 all-gather is turned on
+    # this method is optional but is highly recommended for performance
+    # it calculates scales for all parameters in a single all-reduce
+    precompute_float8_dynamic_scale_for_fsdp(model)
+
 ```
 
 ## float8 linear with delayed scaling
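
A note on placement: the call must come after `optimizer.step()` so the precomputed scales reflect the just-updated weights, and the next iteration's FSDP pre-forward all-gather then consumes them. A condensed, hedged sketch of just that ordering (the toy model, optimizer, and data below are stand-ins; without float8 + FSDP2 weights the call is simply a no-op):

```
import torch
import torch.nn as nn
from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp

# stand-in setup; in the README snippet the linears have been swapped to
# Float8Linear and the model wrapped with FSDP before this loop
model = nn.Sequential(nn.Linear(32, 32), nn.Linear(32, 32))
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

for _ in range(3):
    optimizer.zero_grad()
    loss = model(torch.randn(8, 32)).sum()
    loss.backward()
    optimizer.step()  # weights are updated here
    # run after the step so the scales match the updated weights; the next
    # iteration's fp8 all-gather reuses them (no-op here without float8/FSDP2)
    precompute_float8_dynamic_scale_for_fsdp(model)
```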

float8_experimental/float8_dynamic_utils.py

Lines changed: 36 additions & 7 deletions

@@ -82,7 +82,12 @@ def cast_to_float8_e5m2_dynamic_bw(
 
 class WeightWithDynamicFloat8CastTensor(torch.Tensor):
     @staticmethod
-    def __new__(cls, tensor: torch.Tensor, mm_config: ScaledMMConfig):
+    def __new__(
+        cls,
+        tensor: torch.Tensor,
+        mm_config: ScaledMMConfig,
+        precomputed_scale: Optional[torch.Tensor] = None,
+    ):
         return torch.Tensor._make_wrapper_subclass(
             cls,
             tensor.size(),
@@ -96,9 +101,18 @@ def __new__(cls, tensor: torch.Tensor, mm_config: ScaledMMConfig):
             requires_grad=tensor.requires_grad,
         )
 
-    def __init__(self, tensor: torch.Tensor, mm_config: ScaledMMConfig):
+    def __init__(
+        self,
+        tensor: torch.Tensor,
+        mm_config: ScaledMMConfig,
+        precomputed_scale: Optional[torch.Tensor] = None,
+    ):
         self._tensor = tensor
         self._mm_config = mm_config
+        # for dynamic scaling
+        # `precompute_float8_dynamic_scale_for_fsdp` calculates scales
+        # for all float8 parameters after optimizer step
+        self._precomputed_scale = precomputed_scale
 
     @classmethod
     def __torch_dispatch__(cls, func, types, args, kwargs=None):
@@ -127,20 +141,35 @@ def unwrap(t):
         )
 
     def __tensor_flatten__(self):
-        return ["_tensor"], self._mm_config
+        if self._precomputed_scale:
+            return ["_tensor", "_precomputed_scale"], self._mm_config
+        else:
+            return ["_tensor"], self._mm_config
 
     @staticmethod
     def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
         mm_config = flatten_spec
-        return WeightWithDynamicFloat8CastTensor(inner_tensors["_tensor"], mm_config)
+        return WeightWithDynamicFloat8CastTensor(
+            inner_tensors["_tensor"],
+            mm_config,
+            getattr(inner_tensors, "_precomputed_scale", None),
+        )
 
     def __repr__(self):
         return f"WeightWithDynamicFloat8CastTensor(tensor={self._tensor}, mm_config={self._mm_config})"
 
     def fsdp_pre_all_gather(self, mesh):
-        float8_tensor = cast_to_float8_e4m3_dynamic(
-            self._tensor, self._mm_config, reduce_amax=True
-        )
+        if self._precomputed_scale is not None:
+            float8_tensor = Float8Tensor.to_float8(
+                self._tensor,
+                self._precomputed_scale,
+                torch.float8_e4m3fn,
+                mm_config=self._mm_config,
+            )
+        else:
+            float8_tensor = cast_to_float8_e4m3_dynamic(
+                self._tensor, self._mm_config, reduce_amax=True
+            )
         return (float8_tensor._data,), (float8_tensor._scale,)
 
     def fsdp_post_all_gather(

float8_experimental/fsdp_utils.py

Lines changed: 52 additions & 0 deletions

@@ -0,0 +1,52 @@
+import math
+from typing import List
+
+import torch
+import torch.nn as nn
+from float8_experimental.float8_dynamic_utils import WeightWithDynamicFloat8CastTensor
+from float8_experimental.float8_linear import Float8Linear, TensorScalingType
+from float8_experimental.float8_utils import EPS
+
+
+@torch.no_grad()
+def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
+    """
+    Calculate scale dynamically for all float8 parameters.
+    This should be run after the optimizer step. It performs a single all-reduce to compute the
+    scales for all float8 weights.
+    Example usage:
+        model(input).sum().backward()
+        optim.step()
+        precompute_float8_dynamic_scale_for_fsdp(model)
+    """
+    from torch.distributed._tensor import DTensor
+
+    if any(
+        isinstance(m, Float8Linear) and m.scaling_type_w is TensorScalingType.DELAYED
+        for m in module.modules()
+    ):
+        raise NotImplementedError("Only supports dynamic scaling")
+    float8_linears: List[Float8Linear] = [
+        m
+        for m in module.modules()
+        if isinstance(m, Float8Linear)
+        and isinstance(m.weight, DTensor)
+        and isinstance(m.weight._local_tensor, WeightWithDynamicFloat8CastTensor)
+    ]
+    weights: List[DTensor] = [float8_linear.weight for float8_linear in float8_linears]
+
+    if not weights:
+        return
+
+    # inf-norm is equivalent to max(abs(w))
+    max_weights = torch._foreach_norm(weights, ord=math.inf)  # Partial
+    amax_tensor = torch.vstack(max_weights)  # Partial
+    # clamp is dispatched through DTensor
+    # it will issue a single all-reduce
+    amax_tensor = torch.clamp(amax_tensor, EPS)  # Replicate
+    scale_tensor = torch.finfo(torch.float8_e4m3fn).max / amax_tensor  # Replicate
+    if amax_tensor.dtype is torch.float16:
+        scale_tensor = torch.clamp(scale_tensor, max=torch.finfo(torch.float16).max)
+    scales = torch.split(scale_tensor, 1)  # Replicate
+    for scale, float8_linear in zip(scales, float8_linears):
+        float8_linear.weight._local_tensor._precomputed_scale = scale._local_tensor
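
Stripped of the DTensor plumbing, the per-weight arithmetic in the new function reduces to `scale = float8_max / clamp(amax, EPS)`. A standalone sketch of that math for a single tensor (the `EPS` value below is an assumption standing in for `float8_experimental.float8_utils.EPS`):

```
import torch

EPS = 1e-12  # assumed stand-in for float8_experimental.float8_utils.EPS

def single_tensor_dynamic_scale(w: torch.Tensor) -> torch.Tensor:
    # amax = max(|w|), clamped so a zero tensor never divides by zero
    amax = torch.clamp(w.abs().max(), min=EPS)
    # map the largest magnitude onto the largest representable float8 e4m3 value
    return torch.finfo(torch.float8_e4m3fn).max / amax.float()

w = torch.randn(1024, 1024)
scale = single_tensor_dynamic_scale(w)
# scaling by `scale` keeps the tensor inside the float8 e4m3 dynamic range
assert (w * scale).abs().max() <= torch.finfo(torch.float8_e4m3fn).max * (1 + 1e-6)
```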

test/test_fsdp2/test_fsdp2_common.py

Lines changed: 4 additions & 0 deletions

@@ -6,6 +6,7 @@
 import torch
 import torch.distributed as dist
 import torch.nn as nn
+from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp
 
 
 def check_parity_no_mp(
@@ -15,6 +16,7 @@ def check_parity_no_mp(
     fsdp_model: nn.Module,
     fsdp_optim: torch.optim.Optimizer,
     local_inp: torch.Tensor,
+    precompute: bool = False,
 ):
     for iter_idx in range(10):
         losses: List[torch.Tensor] = []
@@ -28,6 +30,8 @@
                 param.grad.div_(dist.get_world_size())
             # TODO(future): add amax syncing once delayed scaling is supported
             optim.step()
+            if model is fsdp_model and precompute:
+                precompute_float8_dynamic_scale_for_fsdp(model)
         test_cls.assertEqual(losses[0], losses[1])
 
 
test/test_fsdp2/test_fsdp2_eager.py

Lines changed: 17 additions & 4 deletions

@@ -86,10 +86,21 @@ def world_size(self) -> int:
 
     @skip_if_lt_x_gpu(2)
     def test_transformer_parity_dynamic(self):
-        for enable_fsdp_fp8_all_gather in [False, True]:
-            self._test_transformer_parity_dynamic(enable_fsdp_fp8_all_gather)
+        self.run_subtests(
+            {
+                "enable_fsdp_fp8_all_gather": [False, True],
+                "precompute": [False, True],
+            },
+            self._test_transformer_parity_dynamic,
+        )
 
-    def _test_transformer_parity_dynamic(self, enable_fsdp_fp8_all_gather: bool):
+    def _test_transformer_parity_dynamic(
+        self,
+        enable_fsdp_fp8_all_gather: bool,
+        precompute: bool,
+    ):
+        if not enable_fsdp_fp8_all_gather and precompute:
+            return
         # NOTE: Weight-tying does not compose with fp8 all-gather because the
         # embedding weight and output linear weight are tied but only the
         # latter uses fp8 compute. With fp8 all-gather, FSDP would pre-cast to
@@ -109,7 +120,9 @@ def _test_transformer_parity_dynamic(self, enable_fsdp_fp8_all_gather: bool):
         local_inp = torch.randint(
             0, ref_module.tok_embeddings.weight.size(0), (16, 16), device="cuda"
         )
-        check_parity_no_mp(self, ref_module, ref_optim, module, optim, local_inp)
+        check_parity_no_mp(
+            self, ref_module, ref_optim, module, optim, local_inp, precompute
+        )
 
     @skip_if_lt_x_gpu(2)
     def test_transformer_memory(self):
