[FSDP2] precompute scale after optimizer.step for dynamic scaling #266
Changes from 5 commits
@@ -22,7 +22,7 @@
     tensor_already_casted_to_fp8,
     to_fp8_no_autograd,
 )
-from float8_experimental.float8_utils import tensor_to_scale
+from float8_experimental.float8_utils import amax_to_scale, tensor_to_scale
 from torch._prims_common import suggest_memory_format

@@ -144,13 +144,19 @@ def __new__(cls, tensor: torch.Tensor, mm_config: ScaledMMConfig):
             dtype=tensor.dtype,
             layout=tensor.layout,
             device=tensor.device,
-            pin_memory=tensor.is_pinned(),
+            # TODO: workaround for fake tensor not implementing is_pinned()
+            # pin_memory=tensor.is_pinned(),
+            pin_memory=False,
             requires_grad=tensor.requires_grad,
         )

     def __init__(self, tensor: torch.Tensor, mm_config: ScaledMMConfig):
         self._tensor = tensor
         self._mm_config = mm_config
+        # Optional cache for pre-computed fp8 data/scale/amax
+        self._fp8_data: Optional[torch.Tensor] = None
+        self._fp8_scale: Optional[torch.Tensor] = None
+        self._fp8_amax: Optional[torch.Tensor] = None

     @classmethod
     def __torch_dispatch__(cls, func, types, args, kwargs=None):
@@ -190,9 +196,22 @@ def __repr__(self):
         return f"WeightWithDynamicFloat8CastTensor(tensor={self._tensor}, mm_config={self._mm_config})"

     def fsdp_pre_all_gather(self, mesh):
-        float8_tensor = cast_to_float8_e4m3fn(
-            self._tensor, self._mm_config, reduce_amax=True
-        )
+        if self._fp8_data is not None and self._fp8_scale is not None:
+            return (self._fp8_data,), (self._fp8_scale,)
+        if self._fp8_amax is not None:
+            scale = amax_to_scale(
+                self._fp8_amax,
+                torch.float8_e4m3fn,
+                self._fp8_amax.dtype,
+                clamp_amax=False,
+            )
+            float8_tensor = Float8Tensor.to_float8(
+                self._tensor, scale, torch.float8_e4m3fn, mm_config=self._mm_config
+            )
+        else:
+            float8_tensor = cast_to_float8_e4m3fn(
+                self._tensor, self._mm_config, reduce_amax=True
+            )
         return (float8_tensor._data,), (float8_tensor._scale,)

     def fsdp_post_all_gather(
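As a sanity check on the new cached-amax path, here is a small sketch (not part of the PR). It only uses names that appear in this diff and assumes amax_to_scale and tensor_to_scale are importable from float8_experimental.float8_utils as shown above, with the clamp_amax argument added by this PR. A pre-clamped amax with clamp_amax=False should reproduce the scale that the default dynamic cast computes locally:

    import torch
    from float8_experimental.float8_utils import EPS, amax_to_scale, tensor_to_scale

    w = torch.randn(512, 512)
    # what precompute_float8_amax would cache: the local amax, clamped to EPS
    amax = torch.clamp(w.abs().max(), EPS)
    s_cached = amax_to_scale(amax, torch.float8_e4m3fn, w.dtype, clamp_amax=False)
    # the default path recomputes the amax and clamps inside amax_to_scale
    s_dynamic = tensor_to_scale(w, torch.float8_e4m3fn)
    torch.testing.assert_close(s_cached, s_dynamic)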
@@ -5,16 +5,25 @@
 # LICENSE file in the root directory of this source tree.
 import copy
 import logging
+import warnings
 from enum import auto, Enum
 from typing import Callable, List, Optional, Type

 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
+from float8_experimental.float8_dynamic_linear import (
+    Float8DynamicLinear,
+    WeightWithDynamicFloat8CastTensor,
+)
 from float8_experimental.float8_linear import Float8Linear

-from float8_experimental.float8_utils import amax_history_to_scale_stack
+from float8_experimental.float8_utils import (
+    amax_history_to_scale_stack,
+    E4M3_MAX_POS,
+    EPS,
+    to_fp8_saturated,
+)
 from torch.distributed._functional_collectives import all_reduce, AsyncCollectiveTensor

 log = logging.getLogger(__name__)
@@ -322,3 +331,79 @@ def inner_func():
     for child in fp8_layers:
         # Set a flag to signal amaxes/scales are ready
         child.amax_and_scale_synced = True
+
+
+def precompute_float8_amax(module: nn.Module) -> None:
+    from torch.distributed._tensor import DTensor
+
+    if any(isinstance(m, Float8Linear) for m in module.modules()):
+        raise NotImplementedError("Only supports Float8DynamicLinear, not Float8Linear")
+    float8_linears: List[Float8DynamicLinear] = [
+        m
+        for m in module.modules()
+        if isinstance(m, Float8DynamicLinear)
+        and isinstance(m.weight, DTensor)
+        and isinstance(m.weight._local_tensor, WeightWithDynamicFloat8CastTensor)
+    ]
+    weights: List[DTensor] = [float8_linear.weight for float8_linear in float8_linears]
+
+    def compute_amaxes(weights: List[DTensor]):
+        abs_weights = torch._foreach_abs(weights)  # S0
+        amax_tensor = torch.vstack([torch.max(a) for a in abs_weights])  # P
+        amax_tensor = torch.clamp(amax_tensor, EPS)  # R
+        amaxes = torch.split(amax_tensor, 1)  # R
+        return amaxes
+
+    if weights:
+        # amaxes = compute_amaxes(weights)
+        # amaxes = torch.compile(compute_amaxes, mode="reduce-overhead")(weights)
+        amaxes = torch.compile(compute_amaxes)(weights)
+        for amax, float8_linear in zip(amaxes, float8_linears):
+            float8_linear.weight._local_tensor._fp8_amax = amax._local_tensor
+    else:
+        warnings.warn(
+            "Calling precompute_float8_amax without any weights using FSDP fp8 all-gather!"
+        )

Review thread on def precompute_float8_amax:
- Reviewer: can we put this in […]? I think the function name should include that this is intended for FSDP2 with float8 all-gather.
- weifengpy: moving to […]
- weifengpy: indicating fsdp by renaming to […]
- bdhirsh: @weifengpy do you plan / want to use compile on this, and are there any gaps around here that you think would be good to prioritize on the compile side? This is mostly just me remembering @awgu mention a while ago that he thought compile added noticeable runtime overhead, and I can't remember if it was for this specific case. If it is, and we think compiling this code would be useful, I can prioritize looking into the runtime overhead.
- weifengpy: Hi @bdhirsh, I plan to polish and land this PR without compile next week to conclude H1, most importantly add […]. Reducing runtime overhead from […]
- bdhirsh: If you have a mini repro showing bad runtime overheads with compile, that would be great!
- weifengpy: Hi @bdhirsh, I have created a repro: pytorch/pytorch#129457. I highlighted extra cpu overhead and gpu time for torch.compile(mode="reduce-overhead").

Review thread on amax_tensor = torch.clamp(amax_tensor, EPS):
- Reviewer: So you are relying on […]. If this fragments the code, could we just all-reduce the amax tensor and then leave the clamp to […]?
- weifengpy: thanks for the suggestions. I can collect feedback from float8 folks if they have a preference.
- Reviewer: can we just comment with what is going on? I think it's fine as long as the code is easy to understand and there is no magic.
- weifengpy: agreed
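For context, a hypothetical call site (a sketch only: model, optimizer, and dataloader are illustrative, and the float8_experimental.float8_linear_utils import path is an assumption based on this file). Per the PR title, the precompute runs right after optimizer.step(), so the next forward's FSDP all-gather can reuse the cached amax instead of recomputing it inside fsdp_pre_all_gather:

    import torch
    # assumed import path; precompute_float8_amax is the function added above
    from float8_experimental.float8_linear_utils import precompute_float8_amax

    for batch in dataloader:
        loss = model(batch).sum()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        # cache one amax per float8 weight so the next all-gather skips abs().max()
        precompute_float8_amax(model)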
+
+
+def precompute_float8_weights(module: nn.Module) -> None:
+    from torch.distributed._tensor import DTensor
+
+    if any(isinstance(m, Float8Linear) for m in module.modules()):
+        raise NotImplementedError("Only supports Float8DynamicLinear, not Float8Linear")
+    float8_linears: List[Float8DynamicLinear] = [
+        m
+        for m in module.modules()
+        if isinstance(m, Float8DynamicLinear)
+        and isinstance(m.weight, DTensor)
+        and isinstance(m.weight._local_tensor, WeightWithDynamicFloat8CastTensor)
+    ]
+    weights: List[DTensor] = [float8_linear.weight for float8_linear in float8_linears]
+
+    def compute_weights_and_scales(weights: List[DTensor]):
+        abs_weights = torch._foreach_abs(weights)  # S0
+        # abs_weights = [torch.abs(w) for w in weights]
+        amax_tensor = torch.vstack([torch.max(a) for a in abs_weights])  # P
+        amax_tensor = torch.clamp(amax_tensor, EPS)  # R
+        scales_tensor = E4M3_MAX_POS / amax_tensor  # R
+        scales = torch.split(scales_tensor, 1)  # R
+        weights_scaled = torch._foreach_mul(weights, scales)  # S0
+        datas = [to_fp8_saturated(w, torch.float8_e4m3fn) for w in weights_scaled]  # S0
+        # torch._foreach_clamp_min_(weights_scaled, -1 * E4M3_MAX_POS)
+        # torch._foreach_clamp_max_(weights_scaled, E4M3_MAX_POS)
+        # datas = [w.to(torch.float8_e4m3fn) for w in weights_scaled]
+        return datas, scales
+
+    if weights:
+        # datas, scales = compute_weights_and_scales(weights)
+        datas, scales = torch.compile(compute_weights_and_scales)(weights)
+        # datas, scales = torch.compile(compute_weights_and_scales, mode="reduce-overhead")(weights)
+        for data, scale, float8_linear in zip(datas, scales, float8_linears):
+            float8_linear.weight._local_tensor._fp8_data = data._local_tensor
+            float8_linear.weight._local_tensor._fp8_scale = (
+                scale._local_tensor.squeeze()
+            )
+    else:
+        warnings.warn(
+            "Calling precompute_float8_weights without any weights using FSDP fp8 all-gather!"
+        )
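To make the per-tensor math above concrete, here is a small sketch of the scale-then-saturate step (not PR code; it reuses E4M3_MAX_POS, EPS, and to_fp8_saturated from the imports added in this file, and assumes to_fp8_saturated clamps to the format's max before casting):

    import torch
    from float8_experimental.float8_utils import E4M3_MAX_POS, EPS, to_fp8_saturated

    w = torch.randn(4, 4)
    amax = torch.clamp(w.abs().max(), EPS)
    scale = E4M3_MAX_POS / amax                      # same formula as scales_tensor above
    w_fp8 = to_fp8_saturated(w * scale, torch.float8_e4m3fn)
    w_roundtrip = w_fp8.to(torch.float32) / scale    # dequantize to inspect the error
    print((w - w_roundtrip).abs().max())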
@@ -24,12 +24,18 @@


 @torch.no_grad()
-def amax_to_scale(amax, float8_dtype, orig_dtype):
+def amax_to_scale(amax, float8_dtype, orig_dtype, clamp_amax=True):
     scale = torch.empty_like(amax, dtype=torch.float32)
     if float8_dtype == torch.float8_e4m3fn:
-        res = E4M3_MAX_POS / torch.clamp(amax, min=EPS)
+        if clamp_amax:
+            res = E4M3_MAX_POS / torch.clamp(amax, min=EPS)
+        else:
+            res = E4M3_MAX_POS / amax
     else:  # e5m2
-        res = E5M2_MAX_POS / torch.clamp(amax, min=EPS)
+        if clamp_amax:
+            res = E5M2_MAX_POS / torch.clamp(amax, min=EPS)
+        else:
+            res = E5M2_MAX_POS / amax

     # Ensure that the scale is representable in float16,
     # this helps when amax is small. We are assuming that we don't need

Review comment on the new clamp_amax branch:
- Reviewer: nit: I think if you have this on a separate line it makes the logic a lil easier to follow.
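A small illustration of what the clamp protects against (not PR code; EPS is the small positive constant from float8_utils, imported as shown earlier in this PR): with an all-zero tensor the amax is 0.0, so dividing without the clamp yields an infinite scale. That is why the cached path only skips the clamp when the amax was already clamped upstream by precompute_float8_amax:

    import torch
    from float8_experimental.float8_utils import E4M3_MAX_POS, EPS

    amax = torch.tensor(0.0)
    print(E4M3_MAX_POS / amax)                    # inf: unusable as a scale
    print(E4M3_MAX_POS / torch.clamp(amax, EPS))  # large but finite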
Review thread on the WeightWithDynamicFloat8CastTensor changes:
- Reviewer: One major requirement for tensor subclasses that I don't think is respected here: __tensor_flatten__ and __tensor_unflatten__ must properly convey every inner tensor on the subclass. So when we call __tensor_flatten__ on this subclass, if any of _fp8_data/scale/amax are set to valid tensors, they need to be returned there (and similarly __tensor_unflatten__ needs to handle them as extra args).
- weifengpy: thanks for pointing this out! This saves me a lot of debugging time. I can give it a try by including _fp8_data/scale/amax in __tensor_flatten__ and __tensor_unflatten__.
- weifengpy: torch.compile works after patching pytorch/pytorch#127431. Will compare traces in 2nd PR.