This repository was archived by the owner on Aug 7, 2024. It is now read-only.

delete Float8DynamicLinear #304

Closed
wants to merge 1 commit into from
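A minimal migration sketch (not part of this diff), assuming the APIs shown in the changes below: after this PR, dynamic scaling is requested through Float8Linear's per-tensor scaling types rather than by swapping in the separate Float8DynamicLinear class. The module shapes and the "dynamic" settings here are illustrative only; keyword names follow the benchmark and utility calls in the diff.

import copy

import torch
from float8_experimental.float8_linear import Float8Linear, TensorScalingType
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear

m = torch.nn.Sequential(torch.nn.Linear(1024, 1024, bias=False))

# Before this PR (roughly): swap_linear_with_float8_linear(m, Float8DynamicLinear)
# After this PR: one module class; scaling behavior is chosen per tensor.
swap_linear_with_float8_linear(
    m,
    scaling_type_x=TensorScalingType("dynamic"),
    scaling_type_w=TensorScalingType("dynamic"),
    scaling_type_dL_dY=TensorScalingType("dynamic"),
)

# Single-module path, mirroring the benchmark change below.
linear_fp8 = Float8Linear.from_float(
    copy.deepcopy(torch.nn.Linear(1024, 1024, bias=False)),
    emulate=False,
    scaling_type_x=TensorScalingType("dynamic"),
    scaling_type_w=TensorScalingType("dynamic"),
    scaling_type_dL_dY=TensorScalingType("dynamic"),
)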
65 changes: 15 additions & 50 deletions benchmarks/bench_linear_float8.py
@@ -14,11 +14,9 @@

import torch
import torch.utils.benchmark as benchmark
from float8_experimental.float8_linear import TensorScalingType
from float8_experimental.float8_linear import Float8Linear, TensorScalingType
from float8_experimental.float8_linear_utils import (
get_float8_linear,
linear_requires_sync,
LinearType,
sync_float8_amax_and_scale_history,
)
from float8_experimental.float8_tensor import ScaledMMConfig
@@ -69,7 +67,6 @@ class Experiment:
dtype: torch.dtype
compiled: bool
use_fast_accum: bool
linear_type: str
scaling_repr: str

# 3 Times since we are calculating forward backward
@@ -98,7 +95,6 @@ def main(
n_limit: Optional[int] = None,
fast_accum_filter: Optional[bool] = None,
shape_name_filter: Optional[str] = None,
linear_type_filter: Optional[str] = None,
scaling_type_x: str = "delayed",
scaling_type_w: str = "delayed",
scaling_type_dL_dY: str = "delayed",
@@ -123,44 +119,28 @@
use_fast_accum = [fast_accum_filter]
else:
use_fast_accum = [True, False]
if linear_type_filter is not None:
linear_types = [linear_type_filter]
else:
linear_types = ["delayed", "dynamic"]
if shape_name_filter is not None:
k = shape_name_filter
name_to_shapes_70b = {k: name_to_shapes_70b[k]}
experiment_list: List[Experiment] = []
dtype = torch.bfloat16
for idx, (fast_accum, (name, (K, N)), linear_type) in enumerate(
tqdm(list(product(use_fast_accum, name_to_shapes_70b.items(), linear_types)))
for idx, (fast_accum, (name, (K, N))) in enumerate(
tqdm(list(product(use_fast_accum, name_to_shapes_70b.items())))
):
if n_limit is not None and idx >= n_limit:
break
linear_ref = torch.nn.Linear(K, N, bias=input_bias).to(
device=device, dtype=dtype
)
linear_type_enum = (
LinearType.DELAYED if linear_type == "delayed" else LinearType.DYNAMIC
)

if linear_type == "delayed":
linear_float8 = get_float8_linear(
linear_type_enum,
copy.deepcopy(linear_ref),
emulate=False,
scaling_type_x=scaling_type_x,
scaling_type_w=scaling_type_w,
scaling_type_dL_dY=scaling_type_dL_dY,
)
scaling_repr = linear_float8.scaling_repr()
else:
linear_float8 = get_float8_linear(
linear_type_enum,
copy.deepcopy(linear_ref),
emulate=False,
)
scaling_repr = None
linear_float8 = Float8Linear.from_float(
copy.deepcopy(linear_ref),
emulate=False,
scaling_type_x=scaling_type_x,
scaling_type_w=scaling_type_w,
scaling_type_dL_dY=scaling_type_dL_dY,
)
scaling_repr = linear_float8.scaling_repr()

if fast_accum:
linear_float8.forward_config = ScaledMMConfig(False, True, False)
@@ -172,19 +152,10 @@ def main(
input_tensor = torch.randn(M, K, device=device, dtype=dtype, requires_grad=True)
ref_forw_backward = lambda: linear_ref(input_tensor).sum().backward()

if linear_type_enum == LinearType.DELAYED:

def float8_forw_backward():
if linear_requires_sync(
linear_type_enum, scaling_type_x, scaling_type_w, scaling_type_dL_dY
):
sync_float8_amax_and_scale_history(linear_float8)
linear_float8(input_tensor).sum().backward()

else:

def float8_forw_backward():
linear_float8(input_tensor).sum().backward()
def float8_forw_backward():
if linear_requires_sync(scaling_type_x, scaling_type_w, scaling_type_dL_dY):
sync_float8_amax_and_scale_history(linear_float8)
linear_float8(input_tensor).sum().backward()

def n_times(n, fn, *args, **kwargs):
def wrapper(*args, **kwargs):
@@ -224,7 +195,6 @@ def wrapper(*args, **kwargs):
dtype,
compile,
use_fast_accum=fast_accum,
linear_type=linear_type,
scaling_repr=scaling_repr,
)
print(experiment)
@@ -237,7 +207,6 @@ def wrapper(*args, **kwargs):
"M",
"K",
"N",
"linear_type",
"scaling_repr",
"ref_dtype",
"compiled",
@@ -257,7 +226,6 @@ def wrapper(*args, **kwargs):
experiment.shape[0],
experiment.shape[1],
experiment.shape[2],
experiment.linear_type,
experiment.scaling_repr,
experiment.dtype,
experiment.compiled,
@@ -287,7 +255,6 @@ def wrapper(*args, **kwargs):
[
"name",
"shape",
"linear_type",
"scaling_repr",
"compiled",
"use_fast_accum",
@@ -311,7 +278,6 @@ def invoke_main() -> None:
parser.add_argument("-n", "--n_limit", type=int, required=False)
parser.add_argument("--fast_accum_filter", type=bool, required=False)
parser.add_argument("--shape_name_filter", type=str, required=False)
parser.add_argument("--linear_type_filter", type=str, required=False)
parser.add_argument("--scaling_type_x", type=str, required=False)
parser.add_argument("--scaling_type_w", type=str, required=False)
parser.add_argument("--scaling_type_dL_dY", type=str, required=False)
@@ -330,7 +296,6 @@ def invoke_main() -> None:
args.n_limit,
args.fast_accum_filter,
args.shape_name_filter,
args.linear_type_filter,
**kwargs,
)

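One note on the forward_config override above (an inference from the ScaledMMConfig construction visible later in this PR, not something this diff states): the positional fields appear to be emulate, use_fast_accum, and fp8_output, so ScaledMMConfig(False, True, False) enables fast accumulation without emulation or fp8 output.

from float8_experimental.float8_tensor import ScaledMMConfig

# Assumed field order: (emulate, use_fast_accum, fp8_output, ...).
fast_accum_config = ScaledMMConfig(False, True, False)   # fast accumulation enabled
default_config = ScaledMMConfig(False, False, False)     # fast accumulation disabled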
10 changes: 8 additions & 2 deletions benchmarks/bench_multi_gpu.py
@@ -14,7 +14,7 @@
import torch.multiprocessing as mp
import torch.nn as nn
import torch.utils.benchmark as benchmark
from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_linear import Float8Linear, TensorScalingType
Contributor: How useful is this benchmark in general?

Contributor Author: I haven't used it recently

from float8_experimental.float8_linear_utils import (
swap_linear_with_float8_linear,
sync_float8_amax_and_scale_history,
@@ -65,7 +65,13 @@ def get_model(K, N, is_fp8, base_dtype=torch.float32):
modules.append(nn.ReLU())
m = nn.Sequential(*modules)
if is_fp8:
swap_linear_with_float8_linear(m, Float8Linear, emulate=False)
swap_linear_with_float8_linear(
m,
emulate=False,
scaling_type_x=TensorScalingType.DELAYED,
scaling_type_w=TensorScalingType.DELAYED,
scaling_type_dL_dY=TensorScalingType.DELAYED,
)
return m


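As a usage note (a sketch, not part of the PR): with all three scaling types set to DELAYED, as get_model above now configures, every training step still needs an explicit amax/scale sync. A minimal single-process version of that loop, assuming the imports shown in this file's diff:

import torch
import torch.nn as nn
from float8_experimental.float8_linear import TensorScalingType
from float8_experimental.float8_linear_utils import (
    swap_linear_with_float8_linear,
    sync_float8_amax_and_scale_history,
)

model = nn.Sequential(nn.Linear(4096, 4096, bias=False)).to("cuda", torch.bfloat16)
swap_linear_with_float8_linear(
    model,
    emulate=False,
    scaling_type_x=TensorScalingType.DELAYED,
    scaling_type_w=TensorScalingType.DELAYED,
    scaling_type_dL_dY=TensorScalingType.DELAYED,
)

x = torch.randn(16, 4096, device="cuda", dtype=torch.bfloat16)
for _ in range(10):
    # Delayed scaling keeps amax/scale history buffers that must be synced each step.
    sync_float8_amax_and_scale_history(model)
    model(x).sum().backward()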
49 changes: 23 additions & 26 deletions benchmarks/profile_linear_float8.py
@@ -18,11 +18,9 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
from float8_experimental.float8_linear import Float8Linear, TensorScalingType
from float8_experimental.float8_linear_utils import (
linear_requires_sync,
LinearType,
swap_linear_with_float8_linear,
sync_float8_amax_and_scale_history,
)
@@ -206,19 +204,25 @@ def profile_function(
def main(
profile_path_prefix: Path,
compile: bool = True,
linear_type: str = "dynamic",
scaling_type_x: str = "delayed",
scaling_type_w: str = "delayed",
scaling_type_dL_dY: str = "delayed",
scaling_type_x: str = "dynamic",
scaling_type_w: str = "dynamic",
scaling_type_dL_dY: str = "dynamic",
model_type: str = "linear",
dtype_filter: str = "both",
):
assert model_type in ("linear", "ln_linear", "norm_ffn_norm"), "unsupported"
assert dtype_filter in ("both", "float8", "bfloat16")

print(f"Compile is set to | {compile}")
print(f"Using Linear type: | {linear_type}")
print(f"model_type is set to | {model_type}")
scaling_type_x = TensorScalingType(scaling_type_x)
scaling_type_w = TensorScalingType(scaling_type_w)
scaling_type_dL_dY = TensorScalingType(scaling_type_dL_dY)
scaling_repr = "_".join(
[s.short_str() for s in (scaling_type_x, scaling_type_w, scaling_type_dL_dY)]
)

print(f"Compile is set to | {compile}")
print(f"model_type is set to | {model_type}")
print(f"scaling_repr is set to | {scaling_repr}")

device = "cuda"
ref_dtype = torch.bfloat16
@@ -249,21 +253,14 @@ def main(

m_ref = m_ref.to(device).to(ref_dtype)

linear_type = LinearType[linear_type.upper()]
linear_cls = (
Float8Linear if linear_type is LinearType.DELAYED else Float8DynamicLinear
)
extra_kwargs = {}
scaling_type_x = TensorScalingType(scaling_type_x)
scaling_type_w = TensorScalingType(scaling_type_w)
scaling_type_dL_dY = TensorScalingType(scaling_type_dL_dY)
if linear_type is LinearType.DELAYED:
extra_kwargs["scaling_type_x"] = scaling_type_x
extra_kwargs["scaling_type_w"] = scaling_type_w
extra_kwargs["scaling_type_dL_dY"] = scaling_type_dL_dY
extra_kwargs = {
"scaling_type_x": scaling_type_x,
"scaling_type_w": scaling_type_w,
"scaling_type_dL_dY": scaling_type_dL_dY,
}

m_float8 = copy.deepcopy(m_ref)
swap_linear_with_float8_linear(m_float8, linear_cls, **extra_kwargs)
swap_linear_with_float8_linear(m_float8, **extra_kwargs)

def ref_forw_backward(x):
out = m_ref(x)
@@ -281,9 +278,7 @@ def float8_forw_backward_wrapper(x):
# inspection of the fw+bw torch.compile without the scale
# syncing code
# TODO(future): make this better
if linear_requires_sync(
linear_type, scaling_type_x, scaling_type_w, scaling_type_dL_dY
):
if linear_requires_sync(scaling_type_x, scaling_type_w, scaling_type_dL_dY):
with record_function("scale_amax_and_scales"):
sync_amax_history(m_float8)
out = float8_forw(x)
@@ -345,7 +340,9 @@ def float8_forw_backward_wrapper(x):
if dtype_filter != "bfloat16":
# Profile Float8 Model
print("profiling float8")
float8_suffix = f"_{model_type}_float8_compile_{compile}_{linear_type}.json"
float8_suffix = (
f"_{model_type}_float8_compile_{compile}_{scaling_repr}.json"
)
float8_path = profile_path_prefix + float8_suffix
profile_config = ProfileConfig(
float8_path,
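A related reading of the new call above (an assumption about float8_linear_utils, which is not shown in this diff): linear_requires_sync now depends only on the three scaling types, and sync should be required exactly when at least one tensor uses delayed scaling. A sketch of that check:

from float8_experimental.float8_linear import TensorScalingType

def linear_requires_sync_sketch(scaling_type_x, scaling_type_w, scaling_type_dL_dY):
    # Assumed behavior: amax/scale sync is only needed if any of the three
    # tensors (x, w, dL_dY) is configured for delayed scaling.
    return any(
        t is TensorScalingType.DELAYED
        for t in (scaling_type_x, scaling_type_w, scaling_type_dL_dY)
    )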
float8_experimental/float8_dynamic_linear.py → float8_experimental/float8_dynamic_utils.py
@@ -53,64 +53,6 @@ def backward(ctx, gradY):
return fp8_tensor, None


class Float8DynamicLinear(torch.nn.Linear):
"""
A wrapper around a `torch.nn.Linear` module which does fp8 compute. By on the fly
conversion to fp8 of the input and weight tensors.
"""

def __init__(self, **super_kwargs):
super().__init__(**super_kwargs)

def forward(self, input: torch.Tensor) -> torch.Tensor:
x_fp8 = cast_to_float8_e4m3_dynamic(input, self.forward_config)
if isinstance(self.weight, Float8Tensor): # cast by FSDP
w_fp8 = self.weight
else:
w_fp8 = cast_to_float8_e4m3_dynamic(self.weight, self.forward_config)
y = torch.nn.functional.linear(x_fp8, w_fp8, self.bias)
y = cast_to_float8_e5m2_dynamic_bw(y, self.backward_config)
return y

@classmethod
def from_float(cls, mod, emulate: bool = False) -> "Float8DynamicLinear":
"""
Create an nn.Linear with fp8 compute from a regular nn.Linear

Args:
mod (torch.nn.Linear): nn.Linear to convert
emulate (bool): whether to emulate fp8 matmul logic in float32
"""
with torch.device("meta"):
super_kwargs = {
"in_features": mod.in_features,
"out_features": mod.out_features,
"bias": False,
}
new_mod = cls(**super_kwargs)

new_mod.forward_config = ScaledMMConfig(
emulate=emulate,
use_fast_accum=not bool(emulate),
fp8_output=False,
pad_inner_dim=config.pad_inner_dim,
)
new_mod.backward_config = ScaledMMConfig(
emulate=emulate,
use_fast_accum=False,
fp8_output=False,
pad_inner_dim=config.pad_inner_dim,
)
if config.enable_fsdp_fp8_all_gather:
new_mod.weight = nn.Parameter(
WeightWithDynamicFloat8CastTensor(mod.weight, new_mod.forward_config)
)
else:
new_mod.weight = mod.weight
new_mod.bias = mod.bias
return new_mod


def cast_to_float8_e4m3_dynamic(
inpt_tensor: torch.Tensor, mm_config: ScaledMMConfig, reduce_amax: bool = False
) -> Float8Tensor:
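For readers tracking where the deleted forward logic goes: the dynamic-cast helpers survive (the next file imports them from float8_dynamic_utils), and Float8Linear applies them when a tensor's scaling type is dynamic. A hedged functional sketch using only names visible in this PR; forward_config and backward_config stand for ScaledMMConfig values like those built in the deleted from_float above:

import torch
from float8_experimental.float8_dynamic_utils import (
    cast_to_float8_e4m3_dynamic,
    cast_to_float8_e5m2_dynamic_bw,
)

def dynamic_fp8_linear(x, weight, bias, forward_config, backward_config):
    # On-the-fly e4m3 casts of the activation and weight, as the deleted forward did.
    x_fp8 = cast_to_float8_e4m3_dynamic(x, forward_config)
    w_fp8 = cast_to_float8_e4m3_dynamic(weight, forward_config)
    y = torch.nn.functional.linear(x_fp8, w_fp8, bias)
    # Cast dL/dY to e5m2 dynamically during the backward pass.
    return cast_to_float8_e5m2_dynamic_bw(y, backward_config)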
6 changes: 3 additions & 3 deletions float8_experimental/float8_linear.py
@@ -16,7 +16,7 @@

import torch

from float8_experimental.float8_dynamic_linear import (
from float8_experimental.float8_dynamic_utils import (
cast_to_float8_e4m3_dynamic,
cast_to_float8_e5m2_dynamic_bw,
WeightWithDynamicFloat8CastTensor,
@@ -402,8 +402,8 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:

def scaling_repr(self):
# add scaling settings without using too many characters
# example: "x:del,w:del,dldy:dyn"
return f"x:{self.scaling_type_x.short_str()},w:{self.scaling_type_w.short_str()},dldy:{self.scaling_type_dL_dY.short_str()}"
# example: "x_del_w_del_dldy_dyn"
return f"x_{self.scaling_type_x.short_str()}_w_{self.scaling_type_w.short_str()}_dldy_{self.scaling_type_dL_dY.short_str()}"
Contributor: Why the change, out of curiosity? I think the prior version might be a little more readable.

Contributor Author: I should have reverted this. Will follow up in a future PR if that's ok, to make landing this PR easier.

def extra_repr(self):
s = f'{super().extra_repr()}, scaling="{self.scaling_repr()}"'