This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 9fcf40a

support delayed scaling of weight in float8 all-gather
Summary:

Adds support for delayed scaling in FSDP2 float8 all-gather. In detail:

1. Add `WeightWithDelayedFloat8CastTensor`. Note that we don't reuse code with the dynamic version because I'd rather not deal with plumbing optional tensors through dynamo; we can try that in a separate PR later.
2. Wire `Float8Linear` to use (1).
3. Add weight amax syncing back, since we need it for float8 all-gather.
4. Add test coverage for eager mode numerics.

Next up (in separate PRs): training run validation for numerics, and a look at performance.

Test Plan:

```
./test/test_everything.sh
```

ghstack-source-id: f1707c1
Pull Request resolved: #312
1 parent 1ba79c8 commit 9fcf40a
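
For orientation, the pieces in this commit combine roughly as follows. This is a sketch, not part of the commit: `swap_linear_with_float8_linear`, FSDP2's `fully_shard`, the direct assignment to `config.enable_fsdp_fp8_all_gather`, and an already-initialized process group are all assumptions about the surrounding codebase.

```python
# Hypothetical usage sketch (not part of this commit). Assumes an initialized
# process group, float8_experimental's swap_linear_with_float8_linear helper,
# and FSDP2's fully_shard API.
import torch
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard

import float8_experimental.config as config
from float8_experimental.float8_linear import Float8Linear, TensorScalingType
from float8_experimental.float8_linear_utils import (
    linear_requires_sync,
    swap_linear_with_float8_linear,  # assumed helper, not shown in this diff
    sync_float8_amax_and_scale_history,
)

# enable float8 all-gather before converting the model (flag name from the diff;
# the project may prefer a setter or context manager for this)
config.enable_fsdp_fp8_all_gather = True

model = nn.Sequential(nn.Linear(1024, 1024), nn.Linear(1024, 1024)).cuda()
swap_linear_with_float8_linear(
    model, Float8Linear, scaling_type_w=TensorScalingType.DELAYED
)
fully_shard(model)  # weights become WeightWithDelayedFloat8CastTensor shards

optim = torch.optim.SGD(model.parameters(), lr=1e-3)
for _ in range(10):
    loss = model(torch.randn(16, 1024, device="cuda")).sum()
    loss.backward()
    # delayed scaling needs amax/scale syncing; this commit adds the weight
    # amaxes back into that sync because float8 all-gather consumes them
    if linear_requires_sync(scaling_type_w=TensorScalingType.DELAYED):
        sync_float8_amax_and_scale_history(model)
    optim.step()
    optim.zero_grad()
```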

File tree

5 files changed (+315, -52 lines)

float8_experimental/float8_linear.py

Lines changed: 54 additions & 30 deletions
```diff
@@ -34,7 +34,10 @@
     tensor_to_amax,
 )
 
-from float8_experimental.fsdp_utils import WeightWithDynamicFloat8CastTensor
+from float8_experimental.fsdp_utils import (
+    WeightWithDelayedFloat8CastTensor,
+    WeightWithDynamicFloat8CastTensor,
+)
 
 
 def _maybe_initialize_amaxes_scales_for_float8_cast(
@@ -316,25 +319,28 @@ def cast_w_to_float8(
         self, w: torch.Tensor, is_amax_initialized: bool
     ) -> torch.Tensor:
         if self.scaling_type_w is TensorScalingType.DELAYED:
-            scale_fn_name = self.recipe.scale_fn_name
-            _maybe_initialize_amaxes_scales_for_float8_cast(
-                w,
-                self.fp8_amax_w,
-                self.fp8_amax_history_w,
-                self.fp8_scale_w,
-                scale_fn_name,
-                e4m3_dtype,
-                is_amax_initialized,
-                reduce_amax=False,
-            )
-
-            w_fp8 = Float8Tensor.to_float8(
-                w,
-                self.fp8_scale_w,
-                e4m3_dtype,
-                self.fp8_amax_w,
-                self.forward_config,
-            )
+            if isinstance(self.weight, Float8Tensor):  # cast by FSDP
+                w_fp8 = self.weight
+            else:
+                scale_fn_name = self.recipe.scale_fn_name
+                _maybe_initialize_amaxes_scales_for_float8_cast(
+                    w,
+                    self.fp8_amax_w,
+                    self.fp8_amax_history_w,
+                    self.fp8_scale_w,
+                    scale_fn_name,
+                    e4m3_dtype,
+                    is_amax_initialized,
+                    reduce_amax=False,
+                )
+
+                w_fp8 = Float8Tensor.to_float8(
+                    w,
+                    self.fp8_scale_w,
+                    e4m3_dtype,
+                    self.fp8_amax_w,
+                    self.forward_config,
+                )
         else:
             assert self.scaling_type_w is TensorScalingType.DYNAMIC
             # TODO(future): also support FSDP integration in delayed scaling path
@@ -436,18 +442,36 @@ def from_float(
             scaling_type_dL_dY=scaling_type_dL_dY,
             emulate=emulate,
         )
-        if (
-            scaling_type_w == TensorScalingType.DYNAMIC
-            and config.enable_fsdp_fp8_all_gather
-        ):
-            new_mod.weight = torch.nn.Parameter(
-                WeightWithDynamicFloat8CastTensor(mod.weight, new_mod.forward_config)
-            )
-        else:
-            assert not config.enable_fsdp_fp8_all_gather, "unsupported"
-            new_mod.weight = mod.weight
+        new_mod.weight = mod.weight
         new_mod.bias = mod.bias
         # need to create buffers again when moving from meta device to
         # real device
         new_mod.create_buffers()
+
+        # If FSDP float8 all-gather is on, wrap the weight in a float8-aware
+        # tensor subclass. This must happen last because:
+        # 1. weight needs to be on the correct device to create the buffers
+        # 2. buffers need to be already created for the delayed scaling version
+        #    of the weight wrapper to be initialized
+        if config.enable_fsdp_fp8_all_gather:
+            if scaling_type_w is TensorScalingType.DYNAMIC:
+                new_mod.weight = torch.nn.Parameter(
+                    WeightWithDynamicFloat8CastTensor(
+                        new_mod.weight,
+                        new_mod.forward_config,
+                    )
+                )
+            else:
+                assert scaling_type_w is TensorScalingType.DELAYED
+                new_mod.weight = torch.nn.Parameter(
+                    WeightWithDelayedFloat8CastTensor(
+                        new_mod.weight,
+                        new_mod.fp8_amax_w,
+                        new_mod.fp8_amax_history_w,
+                        new_mod.fp8_scale_w,
+                        new_mod.forward_config,
+                        new_mod.is_amax_initialized,
+                    )
+                )
+
         return new_mod
```
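
The net effect of the `from_float` change above: with `config.enable_fsdp_fp8_all_gather` set, the weight is wrapped in the subclass matching `scaling_type_w`, only after the buffers exist. A minimal sketch; `from_float`'s keyword names beyond those visible in the diff are assumed, and the flag is set directly for brevity.

```python
# Sketch of the new from_float behavior; from_float's exact signature is
# assumed from the diff context, so treat this as illustrative only.
import torch.nn as nn

import float8_experimental.config as config
from float8_experimental.float8_linear import Float8Linear, TensorScalingType
from float8_experimental.fsdp_utils import (
    WeightWithDelayedFloat8CastTensor,
    WeightWithDynamicFloat8CastTensor,
)

config.enable_fsdp_fp8_all_gather = True
m = nn.Linear(256, 256)

delayed = Float8Linear.from_float(m, scaling_type_w=TensorScalingType.DELAYED)
dynamic = Float8Linear.from_float(m, scaling_type_w=TensorScalingType.DYNAMIC)

# the delayed wrapper carries the amax / amax-history / scale buffers with the weight
assert isinstance(delayed.weight, WeightWithDelayedFloat8CastTensor)
assert isinstance(dynamic.weight, WeightWithDynamicFloat8CastTensor)
```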

float8_experimental/float8_linear_utils.py

Lines changed: 7 additions & 6 deletions
```diff
@@ -291,11 +291,10 @@ def inner_func():
         ), "Mismatched lengths of amax tensors."
 
         if dist.is_initialized():
-            # Combine all the amax tensors into one tensor and reduce it
-            # Note: do not reduce the weight values, because FSDP already ensures
-            # the weight values on all ranks are the same after all-gather.
             all_amax_tensors = torch.cat(
-                fp8_amax_x_tensor_list + fp8_amax_dL_dY_tensor_list
+                fp8_amax_x_tensor_list
+                + fp8_amax_w_tensor_list
+                + fp8_amax_dL_dY_tensor_list
             )
             all_reduced_amax_tensor = all_reduce(
                 all_amax_tensors, "MAX", list(range(dist.get_world_size()))
@@ -304,12 +303,14 @@ def inner_func():
             all_reduced_amax_tensor = all_reduced_amax_tensor.wait()
 
             (
-                reduced_fp8_amax_tensor,
+                reduced_fp8_amax_x_tensor,
+                reduced_fp8_amax_w_tensor,
                 reduced_fp8_amax_dL_dY_tensor,
             ) = torch.split(all_reduced_amax_tensor, len(fp8_amax_x_tensor_list))
 
             for idx, child in enumerate(fp8_layers):
-                child.fp8_amax_x.copy_(reduced_fp8_amax_tensor[idx])
+                child.fp8_amax_x.copy_(reduced_fp8_amax_x_tensor[idx])
+                child.fp8_amax_w.copy_(reduced_fp8_amax_w_tensor[idx])
                 child.fp8_amax_dL_dY.copy_(reduced_fp8_amax_dL_dY_tensor[idx])
 
         # We create two stacked tensor groups, one for the amax history and one for the current scales
```
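
To make the new bookkeeping concrete, here is a small single-process sketch of the concatenate/split pattern the sync now uses, with the weight amaxes inserted between the x and dL_dY groups. The tensors are made up; in the real code the concatenated tensor is all-reduced with MAX across ranks before the split.

```python
import torch

# one amax per fp8 layer, for each tensor role (values are illustrative)
fp8_amax_x_tensor_list = [torch.tensor([1.0]), torch.tensor([2.0])]
fp8_amax_w_tensor_list = [torch.tensor([3.0]), torch.tensor([4.0])]
fp8_amax_dL_dY_tensor_list = [torch.tensor([5.0]), torch.tensor([6.0])]

# same ordering as the updated sync: x, then w, then dL_dY
all_amax_tensors = torch.cat(
    fp8_amax_x_tensor_list
    + fp8_amax_w_tensor_list
    + fp8_amax_dL_dY_tensor_list
)
# ... all_reduce(all_amax_tensors, "MAX", ...) would happen here ...
(
    reduced_fp8_amax_x_tensor,
    reduced_fp8_amax_w_tensor,
    reduced_fp8_amax_dL_dY_tensor,
) = torch.split(all_amax_tensors, len(fp8_amax_x_tensor_list))

assert reduced_fp8_amax_w_tensor.tolist() == [3.0, 4.0]
```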

float8_experimental/fsdp_utils.py

Lines changed: 181 additions & 1 deletion
```diff
@@ -6,15 +6,17 @@
 
 from typing import Any, Optional, Tuple
 
+import float8_experimental.config as config
+
 import torch
 import torch.utils._pytree as pytree
 from float8_experimental.float8_dynamic_utils import cast_to_float8_e4m3_dynamic
-
 from float8_experimental.float8_tensor import (
     Float8Tensor,
     merge_mm_configs,
     ScaledMMConfig,
 )
+from float8_experimental.float8_utils import e4m3_dtype
 from torch._prims_common import suggest_memory_format
 
 # FSDP pads its local tensor on dim-0. The subclass should be preserved such
@@ -110,3 +112,181 @@ def fsdp_post_all_gather(
             out._scale = scale
             return
         return Float8Tensor(data, scale, param_dtype, self._mm_config), (data,)
+
+
+class WeightWithDelayedFloat8CastTensor(torch.Tensor):
+    @staticmethod
+    def __new__(
+        cls,
+        tensor: torch.Tensor,
+        amax_buffer: torch.Tensor,
+        amax_history_buffer: torch.Tensor,
+        scale_buffer: torch.Tensor,
+        mm_config: ScaledMMConfig,
+        is_amax_initialized: bool,
+    ):
+        return torch.Tensor._make_wrapper_subclass(
+            cls,
+            tensor.size(),
+            strides=tensor.stride(),
+            storage_offset=tensor.storage_offset(),
+            memory_format=suggest_memory_format(tensor),
+            dtype=tensor.dtype,
+            layout=tensor.layout,
+            device=tensor.device,
+            pin_memory=tensor.is_pinned(),
+            requires_grad=tensor.requires_grad,
+        )
+
+    def __init__(
+        self,
+        tensor: torch.Tensor,
+        amax_buffer: torch.Tensor,
+        amax_history_buffer: torch.Tensor,
+        scale_buffer: torch.Tensor,
+        mm_config: ScaledMMConfig,
+        is_amax_initialized: bool,
+    ):
+        self._tensor = tensor
+        self._amax_buffer = amax_buffer
+        self._amax_history_buffer = amax_history_buffer
+        self._scale_buffer = scale_buffer
+        self._mm_config = mm_config
+
+        # Note: is_amax_initialized is not a buffer to avoid data dependent
+        # control flow visible to dynamo
+        # TODO(future PR): add serialization for this flag
+        self.is_amax_initialized = is_amax_initialized
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args, kwargs=None):
+        if func == torch.ops.aten.detach.default:
+            return WeightWithDelayedFloat8CastTensor(
+                args[0]._tensor,
+                args[0]._amax_buffer,
+                args[0]._amax_history_buffer,
+                args[0]._scale_buffer,
+                args[0]._mm_config,
+                args[0].is_amax_initialized,
+            )
+        mm_config: Optional[ScaledMMConfig] = None
+        amax_buffer: Optional[torch.Tensor] = None
+        amax_history_buffer: Optional[torch.Tensor] = None
+        scale_buffer: Optional[torch.Tensor] = None
+        is_amax_initialized: Optional[bool] = None
+
+        def unwrap(t):
+            nonlocal mm_config
+            if mm_config is None:
+                mm_config = t._mm_config
+            else:
+                mm_config = merge_mm_configs(mm_config, t._mm_config)
+            nonlocal amax_buffer
+            if amax_buffer is None:
+                amax_buffer = t._amax_buffer
+            nonlocal amax_history_buffer
+            if amax_history_buffer is None:
+                amax_history_buffer = t._amax_history_buffer
+            nonlocal scale_buffer
+            if scale_buffer is None:
+                scale_buffer = t._scale_buffer
+            nonlocal is_amax_initialized
+            if is_amax_initialized is None:
+                is_amax_initialized = t.is_amax_initialized
+            return t._tensor
+
+        args, kwargs = pytree.tree_map_only(
+            WeightWithDelayedFloat8CastTensor, unwrap, (args, kwargs or {})
+        )
+        out = func(*args, **kwargs)
+        if func not in _ops_to_preserve_subclass:
+            return out
+        return pytree.tree_map_only(
+            torch.Tensor,
+            lambda x: WeightWithDelayedFloat8CastTensor(
+                x,
+                amax_buffer,
+                amax_history_buffer,
+                scale_buffer,
+                mm_config,
+                is_amax_initialized,
+            ),
+            out,
+        )
+
+    def __tensor_flatten__(self):
+        return (
+            [
+                "_tensor",
+                "_amax_buffer",
+                "_amax_history_buffer",
+                "_scale_buffer",
+            ],
+            (
+                self._mm_config,
+                self.is_amax_initialized,
+            ),
+        )
+
+    @staticmethod
+    def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
+        mm_config, is_amax_initialized = flatten_spec
+        return WeightWithDelayedFloat8CastTensor(
+            inner_tensors["_tensor"],
+            inner_tensors["_amax_buffer"],
+            inner_tensors["_amax_history_buffer"],
+            inner_tensors["_scale_buffer"],
+            mm_config,
+            is_amax_initialized,
+        )
+
+    def __repr__(self):
+        return f"WeightWithDelayedFloat8CastTensor(tensor={self._tensor}, amax_buffer={self._amax_buffer}, scale_buffer={self._scale_buffer}, mm_config={self._mm_config})"
+
+    def fsdp_pre_all_gather(self, mesh):
+        # initialize if needed
+        # TODO(before land): ensure settings are consistent between Float8Linear and here
+        if not self.is_amax_initialized:
+            from float8_experimental.float8_linear import (
+                _maybe_initialize_amaxes_scales_for_float8_cast,
+            )
+
+            _maybe_initialize_amaxes_scales_for_float8_cast(
+                self._tensor,
+                self._amax_buffer,
+                self._amax_history_buffer,
+                self._scale_buffer,
+                "max",  # TODO(before land): read this from parent
+                e4m3_dtype,
+                self.is_amax_initialized,
+                reduce_amax=True,
+            )
+            self.is_amax_initialized = True
+
+        # this will:
+        # 1. cast the tensor to float8 using `_scale_buffer`
+        # 2. populate `_amax_buffer` inplace
+        # TODO(future PR): clean up all the casting functions and clearly
+        # separate dynamic vs delayed, tech debt has accumulated
+        float8_tensor = Float8Tensor.to_float8(
+            self._tensor,
+            self._scale_buffer,
+            e4m3_dtype,
+            self._amax_buffer,
+            self._mm_config,
+        )
+        return (float8_tensor._data,), (float8_tensor._scale,)
+
+    def fsdp_post_all_gather(
+        self,
+        all_gather_outputs: Tuple[torch.Tensor, ...],
+        metadata: Any,
+        param_dtype: torch.dtype,
+        *,
+        out: Optional[torch.Tensor] = None,
+    ):
+        (data,) = all_gather_outputs
+        (scale,) = metadata
+        if out is not None:
+            assert isinstance(out, Float8Tensor), f"{type(out)}"
+            out._scale = scale
+            return
+        return Float8Tensor(data, scale, param_dtype, self._mm_config), (data,)
```
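
For intuition, a single-process sketch of the pre/post all-gather contract this subclass implements for FSDP2. Calling the hooks by hand, the buffer shapes, and `ScaledMMConfig()`'s defaults are assumptions for illustration; in real use FSDP all-gathers the float8 `_data` across ranks between the two calls.

```python
# Illustrative sketch only: constructs the wrapper by hand and calls the FSDP2
# extension hooks directly, which FSDP normally does for us.
import torch

from float8_experimental.float8_tensor import Float8Tensor, ScaledMMConfig
from float8_experimental.fsdp_utils import WeightWithDelayedFloat8CastTensor

weight = torch.randn(256, 256, dtype=torch.bfloat16)
amax = torch.zeros(1)          # assumed buffer shapes, matching Float8Linear's
amax_history = torch.zeros(16)
scale = torch.ones(1)

w = WeightWithDelayedFloat8CastTensor(
    weight, amax, amax_history, scale, ScaledMMConfig(), False
)

# pre: initialize amax/scale if needed, cast to float8 with the current scale,
# and record the new amax into `amax` as a side effect
(data,), (scale_meta,) = w.fsdp_pre_all_gather(mesh=None)

# ... FSDP would all-gather `data` (raw float8 bytes) across ranks here ...

# post: rebuild a Float8Tensor from the gathered bytes plus the scale metadata
w_fp8, _ = w.fsdp_post_all_gather((data,), (scale_meta,), param_dtype=torch.bfloat16)
assert isinstance(w_fp8, Float8Tensor)
```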

test/test_fsdp2/test_fsdp2_common.py

Lines changed: 11 additions & 2 deletions
```diff
@@ -6,7 +6,11 @@
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from float8_experimental.float8_linear import Float8Linear
+from float8_experimental.float8_linear import Float8Linear, TensorScalingType
+from float8_experimental.float8_linear_utils import (
+    linear_requires_sync,
+    sync_float8_amax_and_scale_history,
+)
 
 
 def check_parity_no_mp(
@@ -16,6 +20,7 @@ def check_parity_no_mp(
     fsdp_model: nn.Module,
     fsdp_optim: torch.optim.Optimizer,
     local_inp: torch.Tensor,
+    scaling_type_w: TensorScalingType = TensorScalingType.DYNAMIC,
 ):
     for iter_idx in range(10):
         losses: List[torch.Tensor] = []
@@ -27,8 +32,12 @@ def check_parity_no_mp(
                 for param in model.parameters():
                     dist.all_reduce(param.grad)
                     param.grad.div_(dist.get_world_size())
-            # TODO(future): add amax syncing once delayed scaling is supported
+
+            if linear_requires_sync(scaling_type_w=scaling_type_w):
+                sync_float8_amax_and_scale_history(model)
+
             optim.step()
+
         test_cls.assertEqual(losses[0], losses[1])
 
 
```
