This repository was archived by the owner on Aug 7, 2024. It is now read-only.

[FSDP2] precompute scale after optimizer.step for dynamic scaling #266

Closed
wants to merge 32 commits
Changes from 5 commits
Commits
32 commits
9d5595c
[FSDP2] set vocab_size=32 to avoid must be divisible by 16 error
weifengpy May 21, 2024
e7005c2
precast after optimizer.step and dump profiler traces
weifengpy May 21, 2024
e41d589
Merge branch 'main' into fsdp2
weifengpy May 21, 2024
e0bee10
precast and preamax unit test
weifengpy May 24, 2024
c0ba5a2
remove duplicate vocab
weifengpy May 24, 2024
8da238e
fused amax
weifengpy May 30, 2024
ffff5ed
Merge branch 'main' into fsdp2
weifengpy Jun 6, 2024
aefa21b
use FP8_TYPES and max
weifengpy Jun 6, 2024
d4a1db7
commit all changes before cleaning
weifengpy Jun 6, 2024
d36e79b
pre_compute and flatten / unflatten
weifengpy Jun 6, 2024
6f244a2
remove unused constant
weifengpy Jun 6, 2024
dc5eab0
torch.compile works
weifengpy Jun 6, 2024
546e979
eager ready
weifengpy Jun 6, 2024
229ede6
linter
weifengpy Jun 6, 2024
d5b3ff6
linter
weifengpy Jun 6, 2024
4f05e04
flatten tensor
weifengpy Jun 25, 2024
3de59af
commit all changes for review before rebasing
weifengpy Jul 8, 2024
ffcd197
rebase on unified float8linear
weifengpy Jul 9, 2024
6b18947
Merge branch 'pytorch-labs:main' into fsdp2
weifengpy Jul 9, 2024
562424c
move precompute to fsdp_utils.py
weifengpy Jul 9, 2024
75e0e45
simplify amax calc
weifengpy Jul 9, 2024
fe95f8b
explain _pre_computed_amax
weifengpy Jul 9, 2024
1cbaa13
fix linter
weifengpy Jul 9, 2024
fe2e0a0
document precompute_float8_amax_for_fsdp
weifengpy Jul 9, 2024
e4eaa2a
rename pre_compute to precompute
weifengpy Jul 9, 2024
e4245e4
Merge branch 'main' into fsdp2
weifengpy Jul 10, 2024
e12c973
remove clamp_amax=True/False
weifengpy Jul 10, 2024
9ef67fb
precompute scale
weifengpy Jul 10, 2024
fa2f08a
unit test for precomputing scales
weifengpy Jul 10, 2024
ba085e5
add precompute scale in README
weifengpy Jul 10, 2024
ac0afb0
rename to precompute_float8_dynamic_scale_for_fsdp
weifengpy Jul 11, 2024
8e56dfc
rename to precompute_float8_dynamic_scale_for_fsdp
weifengpy Jul 11, 2024
29 changes: 24 additions & 5 deletions float8_experimental/float8_dynamic_linear.py
@@ -22,7 +22,7 @@
tensor_already_casted_to_fp8,
to_fp8_no_autograd,
)
from float8_experimental.float8_utils import tensor_to_scale
from float8_experimental.float8_utils import amax_to_scale, tensor_to_scale
from torch._prims_common import suggest_memory_format


@@ -144,13 +144,19 @@ def __new__(cls, tensor: torch.Tensor, mm_config: ScaledMMConfig):
dtype=tensor.dtype,
layout=tensor.layout,
device=tensor.device,
pin_memory=tensor.is_pinned(),
# TODO: workaround for fake tensor not implementing is_pinned()
# pin_memory=tensor.is_pinned(),
pin_memory=False,
requires_grad=tensor.requires_grad,
)

def __init__(self, tensor: torch.Tensor, mm_config: ScaledMMConfig):
self._tensor = tensor
self._mm_config = mm_config
# Optional cache for pre-computed fp8 data/scale
self._fp8_data: Optional[torch.Tensor] = None
Contributor

One major requirement for tensor subclasses that I don't think is respected here: __tensor_flatten__ and __tensor_unflatten__ must properly convey every inner tensor on the subclass.

So when we call __tensor_flatten__ on this subclass, if any of _fp8_data/_fp8_scale/_fp8_amax are set to valid tensors, they need to be returned there (and similarly __tensor_unflatten__ needs to handle them as extra args).

Contributor Author

Thanks for pointing this out! This saves me a lot of debugging time. I can give it a try by including _fp8_data/_fp8_scale/_fp8_amax in __tensor_flatten__ and __tensor_unflatten__.

Contributor Author

torch.compile works after patching pytorch/pytorch#127431; will compare traces in a 2nd PR.
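
For illustration, a minimal sketch of what conveying these optional caches through the subclass protocol could look like, using the attribute and class names from this diff revision; the implementation that actually landed in later commits may differ:

```python
def __tensor_flatten__(self):
    inner_tensors = ["_tensor"]
    # Report only the caches that are currently populated, so that
    # __tensor_unflatten__ reconstructs exactly the same set.
    for name in ("_fp8_data", "_fp8_scale", "_fp8_amax"):
        if getattr(self, name) is not None:
            inner_tensors.append(name)
    return inner_tensors, {"mm_config": self._mm_config}

@staticmethod
def __tensor_unflatten__(inner_tensors, metadata, outer_size, outer_stride):
    out = WeightWithDynamicFloat8CastTensor(
        inner_tensors["_tensor"], metadata["mm_config"]
    )
    # Restore whichever caches were flattened above.
    for name in ("_fp8_data", "_fp8_scale", "_fp8_amax"):
        if name in inner_tensors:
            setattr(out, name, inner_tensors[name])
    return out
```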

self._fp8_scale: Optional[torch.Tensor] = None
self._fp8_amax: Optional[torch.Tensor] = None

@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs=None):
@@ -190,9 +196,22 @@ def __repr__(self):
return f"WeightWithDynamicFloat8CastTensor(tensor={self._tensor}, mm_config={self._mm_config})"

def fsdp_pre_all_gather(self, mesh):
Contributor Author
If _pre_computed_amax is set, we skip tensor_to_amax and go directly to amax_to_scale.

float8_tensor = cast_to_float8_e4m3fn(
self._tensor, self._mm_config, reduce_amax=True
)
if self._fp8_data is not None and self._fp8_scale is not None:
return (self._fp8_data,), (self._fp8_scale,)
if self._fp8_amax is not None:
scale = amax_to_scale(
self._fp8_amax,
torch.float8_e4m3fn,
self._fp8_amax.dtype,
clamp_amax=False,
)
float8_tensor = Float8Tensor.to_float8(
self._tensor, scale, torch.float8_e4m3fn, mm_config=self._mm_config
)
else:
float8_tensor = cast_to_float8_e4m3fn(
self._tensor, self._mm_config, reduce_amax=True
)
return (float8_tensor._data,), (float8_tensor._scale,)

def fsdp_post_all_gather(
89 changes: 87 additions & 2 deletions float8_experimental/float8_linear_utils.py
@@ -5,16 +5,25 @@
# LICENSE file in the root directory of this source tree.
import copy
import logging
import warnings
from enum import auto, Enum
from typing import Callable, List, Optional, Type

import torch
import torch.distributed as dist
import torch.nn as nn
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
from float8_experimental.float8_dynamic_linear import (
Float8DynamicLinear,
WeightWithDynamicFloat8CastTensor,
)
from float8_experimental.float8_linear import Float8Linear

from float8_experimental.float8_utils import amax_history_to_scale_stack
from float8_experimental.float8_utils import (
amax_history_to_scale_stack,
E4M3_MAX_POS,
EPS,
to_fp8_saturated,
)
from torch.distributed._functional_collectives import all_reduce, AsyncCollectiveTensor

log = logging.getLogger(__name__)
@@ -322,3 +331,79 @@ def inner_func():
for child in fp8_layers:
# Set a flag to signal amaxes/scales are ready
child.amax_and_scale_synced = True


def precompute_float8_amax(module: nn.Module) -> None:
Contributor
can we put this in distributed_utils.py?

I think the function name should include that this is intended for FSDP2 with float8 all-gather

Contributor Author

moving to fsdp_utils.py according to PR #310

Contributor Author

indicating fsdp by renaming to precompute_float8_amax_for_fsdp

Contributor

@weifengpy do you plan / want to use compile on this, and are there any gaps around here that you think would be good to prioritize on the compile side?

This is mostly just me remembering @awgu mention a while ago that he thought compile added noticeable runtime overhead, and I can't remember if it was for this specific case. If it is, and we think compiling this code would be useful, I can prioritize looking into the runtime overhead.

Contributor Author

Hi @bdhirsh, I plan to polish and land this PR without compile next week to conclude H1; most importantly, add _pre_computed_amax to flatten/unflatten.

Reducing runtime overhead from torch.compile is still meaningful since we want torch.compile(fp8 casting) in FSDP2 pre-forward hooks. Would it be helpful if I work on a mini repro with profiler traces? I want to unblock you in the short term.

Contributor

If you have a mini repro showing bad runtime overheads with compile, that would be great!

Contributor Author

Hi @bdhirsh, I have created a repro: pytorch/pytorch#129457. I highlighted the extra CPU overhead and GPU time for torch.compile(mode="reduce-overhead").
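
The linked repro is not reproduced here, but as a rough illustration of the kind of measurement being discussed, a single-GPU micro-benchmark comparing the eager vs. compiled foreach-style amax computation might look like the following (shapes, iteration counts, and the timing harness are arbitrary choices, not taken from the PR):

```python
import torch

def compute_amaxes(weights):
    abs_weights = torch._foreach_abs(weights)
    return torch.vstack([torch.max(a) for a in abs_weights])

weights = [torch.randn(4096, 4096, device="cuda") for _ in range(8)]
compiled = torch.compile(compute_amaxes)

for name, fn in (("eager", compute_amaxes), ("compiled", compiled)):
    for _ in range(3):  # warmup (includes compilation for the compiled variant)
        fn(weights)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(20):
        fn(weights)
    end.record()
    torch.cuda.synchronize()
    print(f"{name}: {start.elapsed_time(end) / 20:.3f} ms/iter")
```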

from torch.distributed._tensor import DTensor

if any(isinstance(m, Float8Linear) for m in module.modules()):
raise NotImplementedError("Only supports Float8DynamicLinear, not Float8Linear")
float8_linears: List[Float8DynamicLinear] = [
m
for m in module.modules()
if isinstance(m, Float8DynamicLinear)
and isinstance(m.weight, DTensor)
and isinstance(m.weight._local_tensor, WeightWithDynamicFloat8CastTensor)
]
weights: List[DTensor] = [float8_linear.weight for float8_linear in float8_linears]

def compute_amaxes(weights: List[DTensor]):
abs_weights = torch._foreach_abs(weights) # S0 (the S0/P/R comments denote DTensor placements: Shard(0), Partial, Replicate)
amax_tensor = torch.vstack([torch.max(a) for a in abs_weights]) # P
amax_tensor = torch.clamp(amax_tensor, EPS) # R
Contributor Author (weifengpy, Jun 6, 2024)

torch.clamp calls all_reduce. I avoided calling it again in amax_to_scale(clamp_amax=False)


So you are relying on torch.clamp to run the all-reduce implicitly from changing sharding from partial to replicate?

If this fragments the code, could we just all-reduce the amax tensor and then leave the clamp to amax_to_scale? I agree the current way is faster since we are doing one clamp for all amaxes, but in case float8 folks are not happy with this fragmentation, this seems like another way.

Contributor Author

thanks for the suggestions. I can collect feedback from float8 folks if they have a preference

Contributor

Can we just add a comment explaining what is going on? I think it's fine as long as the code is easy to understand and there is no magic.

Contributor

agreed

amaxes = torch.split(amax_tensor, 1) # R
return amaxes

if weights:
# amaxes = compute_amaxes(weights)
# amaxes = torch.compile(compute_amaxes, mode="reduce-overhead")(weights)
amaxes = torch.compile(compute_amaxes)(weights)
for amax, float8_linear in zip(amaxes, float8_linears):
float8_linear.weight._local_tensor._fp8_amax = amax._local_tensor
else:
warnings.warn(
"Calling precompute_float8_weights without any weights using FSDP fp8 all-gather!"
)
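
For reference, a minimal sketch of the alternative suggested in the review thread above: make the Partial-to-Replicate all-reduce explicit via DTensor.redistribute and leave the EPS clamp to amax_to_scale (with its default clamp_amax=True). This assumes the same DTensor weights as in this diff and is not what the PR landed:

```python
from typing import List

import torch
from torch.distributed._tensor import DTensor, Replicate


def compute_amaxes_explicit(weights: List[DTensor]):
    abs_weights = torch._foreach_abs(weights)  # Shard(0)
    amax_tensor = torch.vstack([torch.max(a) for a in abs_weights])  # Partial
    # Run the all-reduce explicitly instead of relying on torch.clamp to
    # trigger it; amax_to_scale then applies the EPS clamp per weight.
    amax_tensor = amax_tensor.redistribute(placements=[Replicate()])  # assumes a 1-D mesh
    return torch.split(amax_tensor, 1)
```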


def precompute_float8_weights(module: nn.Module) -> None:
from torch.distributed._tensor import DTensor

if any(isinstance(m, Float8Linear) for m in module.modules()):
raise NotImplementedError("Only supports Float8DynamicLinear, not Float8Linear")
float8_linears: List[Float8DynamicLinear] = [
m
for m in module.modules()
if isinstance(m, Float8DynamicLinear)
and isinstance(m.weight, DTensor)
and isinstance(m.weight._local_tensor, WeightWithDynamicFloat8CastTensor)
]
weights: List[DTensor] = [float8_linear.weight for float8_linear in float8_linears]

def compute_weights_and_scales(weights: List[DTensor]):
abs_weights = torch._foreach_abs(weights) # S0
# abs_weights = [torch.abs(w) for w in weights]
amax_tensor = torch.vstack([torch.max(a) for a in abs_weights]) # P
amax_tensor = torch.clamp(amax_tensor, EPS) # R
scales_tensor = E4M3_MAX_POS / amax_tensor # R
scales = torch.split(scales_tensor, 1) # R
weights_scaled = torch._foreach_mul(weights, scales) # S0
datas = [to_fp8_saturated(w, torch.float8_e4m3fn) for w in weights_scaled] # S0
# torch._foreach_clamp_min_(weights_scaled, -1 * E4M3_MAX_POS)
# torch._foreach_clamp_max_(weights_scaled, E4M3_MAX_POS)
# datas = [w.to(torch.float8_e4m3fn) for w in weights_scaled]
return datas, scales

if weights:
# datas, scales = compute_weights_and_scales(weights)
datas, scales = torch.compile(compute_weights_and_scales)(weights)
# datas, scales = torch.compile(compute_weights_and_scales, mode="reduce-overhead")(weights)
for data, scale, float8_linear in zip(datas, scales, float8_linears):
float8_linear.weight._local_tensor._fp8_data = data._local_tensor
float8_linear.weight._local_tensor._fp8_scale = (
scale._local_tensor.squeeze()
)
else:
warnings.warn(
"Calling precompute_float8_weights without any weights using FSDP fp8 all-gather!"
)
12 changes: 9 additions & 3 deletions float8_experimental/float8_utils.py
@@ -24,12 +24,18 @@


@torch.no_grad()
def amax_to_scale(amax, float8_dtype, orig_dtype):
def amax_to_scale(amax, float8_dtype, orig_dtype, clamp_amax=True):
scale = torch.empty_like(amax, dtype=torch.float32)
if float8_dtype == torch.float8_e4m3fn:
res = E4M3_MAX_POS / torch.clamp(amax, min=EPS)
if clamp_amax:
Contributor
nit: I think if you have this on a separate line,
amax = clamp(amax, eps) if clamp_amax else amax
it makes the logic a little easier to follow.

res = E4M3_MAX_POS / torch.clamp(amax, min=EPS)
else:
res = E4M3_MAX_POS / amax
else: # e5m2
res = E5M2_MAX_POS / torch.clamp(amax, min=EPS)
if clamp_amax:
res = E5M2_MAX_POS / torch.clamp(amax, min=EPS)
else:
res = E5M2_MAX_POS / amax

# Ensure that the scale is representable in float16,
# this helps when amax is small. We are assuming that we don't need
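
Applying the reviewer's nit above, amax_to_scale could hoist the clamp onto its own line so each dtype branch is a single division. A sketch using the diff's constants (E4M3_MAX_POS, E5M2_MAX_POS, EPS); the float16 tail is reconstructed from the comment above and FP16_MAX_POS is an assumed module-level constant, so the merged function may differ:

```python
@torch.no_grad()
def amax_to_scale(amax, float8_dtype, orig_dtype, clamp_amax=True):
    scale = torch.empty_like(amax, dtype=torch.float32)
    # Hoist the optional clamp so each dtype branch is a single division.
    amax = torch.clamp(amax, min=EPS) if clamp_amax else amax
    if float8_dtype == torch.float8_e4m3fn:
        res = E4M3_MAX_POS / amax
    else:  # e5m2
        res = E5M2_MAX_POS / amax
    # Keep the scale representable in float16, which helps when amax is small.
    if orig_dtype is torch.float16:
        res = torch.clamp(res, max=FP16_MAX_POS)  # FP16_MAX_POS assumed defined alongside EPS
    scale.copy_(res)
    return scale
```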
83 changes: 70 additions & 13 deletions test/test_fsdp2/test_fsdp2_eager.py
@@ -1,7 +1,7 @@
import copy
import threading
import unittest
from typing import Any, List
from typing import Any, List, Union

import torch
import torch._dynamo.testing
@@ -11,7 +11,11 @@
Float8DynamicLinear,
WeightWithDynamicFloat8CastTensor,
)
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
from float8_experimental.float8_linear_utils import (
precompute_float8_amax,
precompute_float8_weights,
swap_linear_with_float8_linear,
)
from test_fsdp2_common import (
check_parity_bf16_mp,
check_parity_no_mp,
@@ -57,12 +61,13 @@ def init_multi_module(self) -> nn.Module:
def init_transformer(self, weight_tying: bool) -> nn.Module:
torch.manual_seed(42)
args = ModelArgs(
n_layers=3,
dim=768,
n_heads=12,
n_layers=8,
dim=4096,
n_heads=32,
dropout_p=0.0,
weight_tying=weight_tying,
vocab_size=32,
vocab_size=4096,
max_seq_len=4096,
)
module = Transformer(args).cuda()
self.broadcast_module(module)
@@ -78,17 +83,55 @@ def swap_linear_with_dynamic(self, module: nn.Module, **kwargs: Any) -> nn.Module:
return swap_linear_with_float8_linear(module, Float8DynamicLinear, **kwargs)


def profiler(output_dir):
"""
Utility that wraps `torch.profiler` to profile the model's operators.
See https://pytorch.org/docs/stable/profiler.html for more details.
The schedule for this profiler is: skip 1 step, wait 0 steps, warm up 1 step, trace 2 steps, repeat once.
Note: enabling the PyTorch profiler may reduce training speed.

Args:
output_dir (str): Output path for the exported Chrome trace.

Returns:
ContextManager: PyTorch profiler context manager
"""

def trace_handler(prof) -> None:
prof.export_chrome_trace(output_dir)

return torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
schedule=torch.profiler.schedule(
wait=0, warmup=1, active=2, repeat=1, skip_first=1
),
on_trace_ready=trace_handler,
record_shapes=True,
profile_memory=False,
with_stack=False,
)


class TestFloat8MultiProcess(FSDPTest, TestFloat8Common):
@property
def world_size(self) -> int:
return min(torch.cuda.device_count(), 2)

@skip_if_lt_x_gpu(2)
def test_transformer_parity_dynamic(self):
for enable_fsdp_fp8_all_gather in [False, True]:
self._test_transformer_parity_dynamic(enable_fsdp_fp8_all_gather)
for enable_fsdp_fp8_all_gather in [True]:
for pre_compute in [None, "cast", "amax"]:
self._test_transformer_parity_dynamic(
enable_fsdp_fp8_all_gather, pre_compute
)

def _test_transformer_parity_dynamic(self, enable_fsdp_fp8_all_gather: bool):
def _test_transformer_parity_dynamic(
self, enable_fsdp_fp8_all_gather: bool, pre_compute: Union[str, None]
):
# NOTE: Weight-tying does not compose with fp8 all-gather because the
# embedding weight and output linear weight are tied but only the
# latter uses fp8 compute. With fp8 all-gather, FSDP would pre-cast to
Expand All @@ -106,11 +149,25 @@ def _test_transformer_parity_dynamic(self, enable_fsdp_fp8_all_gather: bool):
ref_optim = torch.optim.Adam(ref_module.parameters(), lr=1e-2)
optim = torch.optim.Adam(module.parameters(), lr=1e-2, foreach=True)
local_inp = torch.randint(
0, ref_module.tok_embeddings.weight.size(0), (16, 16), device="cuda"
)
check_parity_no_mp(
self, ref_module, ref_optim, module, optim, local_inp, Float8DynamicLinear
0, ref_module.tok_embeddings.weight.size(0), (4, 512), device="cuda"
)
with profiler(
output_dir=f"./test_fsdp2_eager_fp8_{enable_fsdp_fp8_all_gather}_{pre_compute}_rank_{torch.distributed.get_rank()}.json"
) as prof:
for i in range(5):
optim.zero_grad()
loss = module(local_inp).sum()
# if torch.distributed.get_rank() == 0:
# print(f"{pre_compute=} {i=} {loss=}")
loss.backward()
optim.step()
if pre_compute is None:
pass
elif pre_compute == "cast":
precompute_float8_weights(module)
elif pre_compute == "amax":
precompute_float8_amax(module)
prof.step()

@skip_if_lt_x_gpu(2)
def test_transformer_memory(self):