This repository was archived by the owner on Aug 7, 2024. It is now read-only.

[1/x]: Make Float8Linear support dynamic scaling #290

Closed
wants to merge 5 commits
12 changes: 6 additions & 6 deletions float8_experimental/float8_dynamic_linear.py
@@ -63,13 +63,13 @@ def __init__(self, **super_kwargs):
super().__init__(**super_kwargs)

def forward(self, input: torch.Tensor) -> torch.Tensor:
x_fp8 = cast_to_float8_e4m3fn(input, self.forward_config)
x_fp8 = cast_to_float8_e4m3_dynamic(input, self.forward_config)
if isinstance(self.weight, Float8Tensor): # cast by FSDP
w_fp8 = self.weight
else:
w_fp8 = cast_to_float8_e4m3fn(self.weight, self.forward_config)
w_fp8 = cast_to_float8_e4m3_dynamic(self.weight, self.forward_config)
y = torch.nn.functional.linear(x_fp8, w_fp8, self.bias)
y = cast_to_float8_e5m2_bw(y, self.backward_config)
y = cast_to_float8_e5m2_dynamic_bw(y, self.backward_config)
return y

@classmethod
@@ -111,7 +111,7 @@ def from_float(cls, mod, emulate: bool = False) -> "Float8DynamicLinear":
return new_mod


def cast_to_float8_e4m3fn(
def cast_to_float8_e4m3_dynamic(
inpt_tensor: torch.Tensor, mm_config: ScaledMMConfig, reduce_amax: bool = False
) -> Float8Tensor:
if tensor_already_casted_to_fp8(inpt_tensor):
@@ -120,7 +120,7 @@ def cast_to_float8_e4m3fn(
return Float8Tensor.to_float8(inpt_tensor, scale, e4m3_dtype, mm_config=mm_config)


def cast_to_float8_e5m2_bw(
def cast_to_float8_e5m2_dynamic_bw(
gradY: torch.Tensor, mm_config: ScaledMMConfig
) -> torch.Tensor:
return NoopFwToFloat8E5M2Bw.apply(gradY, mm_config)
@@ -199,7 +199,7 @@ def __repr__(self):
return f"WeightWithDynamicFloat8CastTensor(tensor={self._tensor}, mm_config={self._mm_config})"

def fsdp_pre_all_gather(self, mesh):
float8_tensor = cast_to_float8_e4m3fn(
float8_tensor = cast_to_float8_e4m3_dynamic(
self._tensor, self._mm_config, reduce_amax=True
)
return (float8_tensor._data,), (float8_tensor._scale,)
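
For reference, a minimal sketch of the renamed dynamic-scaling helpers in use. The default `ScaledMMConfig()` construction and the tensor shapes below are assumptions for illustration, not part of this diff, and real (non-emulated) fp8 matmuls additionally require fp8-capable hardware.

# Illustrative sketch, not part of this PR's diff: the renamed dynamic-scaling
# helpers. Assumes ScaledMMConfig is default-constructible and that the current
# PyTorch build exposes the float8 dtypes.
import torch

from float8_experimental.float8_dynamic_linear import (
    cast_to_float8_e4m3_dynamic,
    cast_to_float8_e5m2_dynamic_bw,
)
from float8_experimental.float8_tensor import ScaledMMConfig

mm_config = ScaledMMConfig()
x = torch.randn(16, 32, requires_grad=True)

# Forward-path cast: compute a scale from amax(x) on the fly, then cast to e4m3.
x_fp8 = cast_to_float8_e4m3_dynamic(x, mm_config)

# Backward-path cast: identity in the forward pass; the gradient flowing through
# this op is cast to e5m2 during autograd.
y = cast_to_float8_e5m2_dynamic_bw(x, mm_config)
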
244 changes: 167 additions & 77 deletions float8_experimental/float8_linear.py
@@ -8,13 +8,19 @@
"""

import dataclasses
import enum

from typing import Optional

import float8_experimental.config as config

import torch

from float8_experimental.float8_dynamic_linear import (
cast_to_float8_e4m3_dynamic,
cast_to_float8_e5m2_dynamic_bw,
)

from float8_experimental.float8_tensor import (
Float8Tensor,
ScaledMMConfig,
@@ -125,20 +131,54 @@ def __init__(self, history_len: int = 16, scale_fn_name: str = "max"):
), f"{self.scale_fn_name} is not implemented yet. Only max is supported for now."


class TensorScalingType(enum.Enum):
DELAYED = "delayed"
DYNAMIC = "dynamic"

def short_str(self):
if self is TensorScalingType.DELAYED:
return "del"
else:
assert self is TensorScalingType.DYNAMIC
return "dyn"


class Float8Linear(torch.nn.Linear):
"""
A wrapper around a `torch.nn.Linear` module which does fp8 compute, and tracks
scales in a way friendly to delayed scaling.
"""

def __init__(self, *args, **kwargs):
"""
Additional arguments on top of `torch.nn.Linear`'s arguments:
* `delayed_scaling_recipe`: configuration for delayed scaling
* `scaling_type_x`: delayed vs dynamic scaling for `x`
* `scaling_type_w`: delayed vs dynamic scaling for `w`
* `scaling_type_dL_dY`: delayed vs dynamic scaling for `dL_dY`
"""

delayed_scaling_recipe = kwargs.pop(
"delayed_scaling_recipe", DelayedScalingRecipe()
)
# Amax scales should always be kept as float32.
self.always_float32_buffers = set()
scaling_type_x = kwargs.pop("scaling_type_x", TensorScalingType.DELAYED)
scaling_type_w = kwargs.pop("scaling_type_w", TensorScalingType.DELAYED)
scaling_type_dL_dY = kwargs.pop("scaling_type_dL_dY", TensorScalingType.DELAYED)
super().__init__(*args, **kwargs)

# Defines the scaling behavior of x, w, dL_dY
self.scaling_type_x = scaling_type_x
self.scaling_type_w = scaling_type_w
self.scaling_type_dL_dY = scaling_type_dL_dY
# Convenience flag to skip code related to delayed scaling
self.has_any_delayed_scaling = (
self.scaling_type_x is TensorScalingType.DELAYED
or self.scaling_type_w is TensorScalingType.DELAYED
or self.scaling_type_dL_dY is TensorScalingType.DELAYED
)

# TODO(future): have a unique recipe per buffer instead of one per
# module, saving implementing that until we need it.
# TODO(future): serialization for recipes
@@ -175,37 +215,44 @@ def create_buffers(self):
# Default values for history buffers, see above TODO
history_len = self.recipe.history_len
device = self.weight.device
# TODO(future PR): dtype values below don't have the other float8
# flavors, fix it
default_x = torch.finfo(torch.float8_e4m3fn).max
default_w = torch.finfo(torch.float8_e4m3fn).max
default_dl_dy = torch.finfo(torch.float8_e5m2).max

self.register_always_float32_buffer(
"fp8_amax_x", torch.tensor([default_x], device=device)
)
self.register_always_float32_buffer(
"fp8_amax_history_x", torch.zeros(history_len, device=device)
)
self.register_always_float32_buffer(
"fp8_scale_x", torch.tensor([1.0], device=device)
)
self.register_always_float32_buffer(
"fp8_amax_w", torch.tensor([default_w], device=device)
)
self.register_always_float32_buffer(
"fp8_amax_history_w", torch.zeros(history_len, device=device)
)
self.register_always_float32_buffer(
"fp8_scale_w", torch.tensor([1.0], device=device)
)
self.register_always_float32_buffer(
"fp8_amax_dL_dY", torch.tensor([default_dl_dy], device=device)
)
self.register_always_float32_buffer(
"fp8_amax_history_dL_dY", torch.zeros(history_len, device=device)
)
self.register_always_float32_buffer(
"fp8_scale_dL_dY", torch.tensor([1.0], device=device)
)
# Note: for now, create all the buffers if any are needed, to postpone
# the work to make the scale and amax syncing and history calculation
# handle a heterogeneous setup. We can do that work later if benchmarks
# show it is worth doing.
if self.has_any_delayed_scaling:
self.register_always_float32_buffer(
"fp8_amax_x", torch.tensor([default_x], device=device)
)
self.register_always_float32_buffer(
"fp8_amax_history_x", torch.zeros(history_len, device=device)
)
self.register_always_float32_buffer(
"fp8_scale_x", torch.tensor([1.0], device=device)
)
self.register_always_float32_buffer(
"fp8_amax_w", torch.tensor([default_w], device=device)
)
self.register_always_float32_buffer(
"fp8_amax_history_w", torch.zeros(history_len, device=device)
)
self.register_always_float32_buffer(
"fp8_scale_w", torch.tensor([1.0], device=device)
)
self.register_always_float32_buffer(
"fp8_amax_dL_dY", torch.tensor([default_dl_dy], device=device)
)
self.register_always_float32_buffer(
"fp8_amax_history_dL_dY", torch.zeros(history_len, device=device)
)
self.register_always_float32_buffer(
"fp8_scale_dL_dY", torch.tensor([1.0], device=device)
)

def register_always_float32_buffer(
self, name: str, tensor: Optional[torch.Tensor], persistent: bool = True
@@ -234,61 +281,77 @@ def cast_x_to_float8(
autocast_dtype = torch.get_autocast_gpu_dtype()
x = x.to(autocast_dtype)

scale_fn_name = self.recipe.scale_fn_name
_maybe_initialize_amaxes_scales_for_float8_cast(
x,
self.fp8_amax_x,
self.fp8_amax_history_x,
self.fp8_scale_x,
scale_fn_name,
e4m3_dtype,
is_amax_initialized,
reduce_amax=True,
)
x_fp8 = Float8Tensor.to_float8(
x,
self.fp8_scale_x,
e4m3_dtype,
self.fp8_amax_x,
self.forward_config,
)
if self.scaling_type_x is TensorScalingType.DELAYED:
scale_fn_name = self.recipe.scale_fn_name
_maybe_initialize_amaxes_scales_for_float8_cast(
x,
self.fp8_amax_x,
self.fp8_amax_history_x,
self.fp8_scale_x,
scale_fn_name,
e4m3_dtype,
is_amax_initialized,
reduce_amax=True,
)
x_fp8 = Float8Tensor.to_float8(
x,
self.fp8_scale_x,
e4m3_dtype,
self.fp8_amax_x,
self.forward_config,
)
else:
assert self.scaling_type_x is TensorScalingType.DYNAMIC
x_fp8 = cast_to_float8_e4m3_dynamic(x, self.forward_config)
return x_fp8

def cast_w_to_float8(
self, w: torch.Tensor, is_amax_initialized: bool
) -> torch.Tensor:
scale_fn_name = self.recipe.scale_fn_name
_maybe_initialize_amaxes_scales_for_float8_cast(
w,
self.fp8_amax_w,
self.fp8_amax_history_w,
self.fp8_scale_w,
scale_fn_name,
e4m3_dtype,
is_amax_initialized,
reduce_amax=False,
)
if self.scaling_type_w is TensorScalingType.DELAYED:
scale_fn_name = self.recipe.scale_fn_name
_maybe_initialize_amaxes_scales_for_float8_cast(
w,
self.fp8_amax_w,
self.fp8_amax_history_w,
self.fp8_scale_w,
scale_fn_name,
e4m3_dtype,
is_amax_initialized,
reduce_amax=False,
)

w_fp8 = Float8Tensor.to_float8(
w,
self.fp8_scale_w,
e4m3_dtype,
self.fp8_amax_w,
self.forward_config,
)
w_fp8 = Float8Tensor.to_float8(
w,
self.fp8_scale_w,
e4m3_dtype,
self.fp8_amax_w,
self.forward_config,
)
else:
assert self.scaling_type_w is TensorScalingType.DYNAMIC
# TODO(future): also support FSDP integration in delayed scaling path
if isinstance(self.weight, Float8Tensor): # cast by FSDP
w_fp8 = self.weight
else:
w_fp8 = cast_to_float8_e4m3_dynamic(self.weight, self.forward_config)
return w_fp8

def cast_y_to_float8_in_bw(self, y: torch.Tensor) -> torch.Tensor:
scale_fn_name = self.recipe.scale_fn_name
y = NoopFwToFloat8E5M2Bw.apply(
y,
self.fp8_amax_dL_dY,
self.fp8_amax_history_dL_dY,
self.fp8_scale_dL_dY,
scale_fn_name,
self.is_amax_initialized,
self.backward_config,
)
if self.scaling_type_dL_dY is TensorScalingType.DELAYED:
scale_fn_name = self.recipe.scale_fn_name
y = NoopFwToFloat8E5M2Bw.apply(
y,
self.fp8_amax_dL_dY,
self.fp8_amax_history_dL_dY,
self.fp8_scale_dL_dY,
scale_fn_name,
self.is_amax_initialized,
self.backward_config,
)
else:
assert self.scaling_type_dL_dY is TensorScalingType.DYNAMIC
y = cast_to_float8_e5m2_dynamic_bw(y, self.backward_config)
return y

def float8_pre_forward(self, x):
@@ -313,7 +376,8 @@ def float8_post_forward(self):
self.amax_and_scale_synced = False

def forward(self, input: torch.Tensor) -> torch.Tensor:
self.float8_pre_forward(input)
if self.has_any_delayed_scaling:
self.float8_pre_forward(input)

x_fp8 = self.cast_x_to_float8(input, self.is_amax_initialized)
w_fp8 = self.cast_w_to_float8(self.weight, self.is_amax_initialized)
@@ -326,11 +390,29 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
if self.bias is not None:
y = y + self.bias.to(y.dtype)

self.float8_post_forward()
if self.has_any_delayed_scaling:
self.float8_post_forward()
return y

def extra_repr(self):
# example: in_features=32, out_features=16, bias=True
s = super().extra_repr()
# add scaling settings without using too many characters
scaling = f"x:{self.scaling_type_x.short_str()},w:{self.scaling_type_w.short_str()},dldy:{self.scaling_type_dL_dY.short_str()}"

s = f'{s}, scaling="{scaling}"'
# example: in_features=32, out_features=16, bias=True, scaling="x:del,w:del,dldy:dyn"
return s

@classmethod
def from_float(cls, mod, emulate: bool = False):
def from_float(
cls,
mod,
emulate: bool = False,
scaling_type_x=TensorScalingType.DELAYED,
scaling_type_w=TensorScalingType.DELAYED,
scaling_type_dL_dY=TensorScalingType.DELAYED,
):
"""
Create an nn.Linear with fp8 compute from a regular nn.Linear

Expand All @@ -339,14 +421,22 @@ def from_float(cls, mod, emulate: bool = False):
emulate (bool): whether to emulate fp8 matmul logic in float32
"""
with torch.device("meta"):
new_mod = cls(mod.in_features, mod.out_features, bias=False)
new_mod = cls(
mod.in_features,
mod.out_features,
bias=False,
scaling_type_x=scaling_type_x,
scaling_type_w=scaling_type_w,
scaling_type_dL_dY=scaling_type_dL_dY,
)
new_mod.weight = mod.weight
new_mod.bias = mod.bias
# need to create buffers again when moving from meta device to
# real device
new_mod.create_buffers()
# Defines the behavior of the matmul in the forward and backward
# Forward we use fast_accum, backwards we do not
# TODO(future PR): move below to the constructor
new_mod.forward_config = ScaledMMConfig(
emulate, True if not emulate else False, False, config.pad_inner_dim
)
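
To close out, a minimal usage sketch of the new per-tensor scaling knobs introduced by this PR. Only `Float8Linear.from_float`, `TensorScalingType`, and the `extra_repr` format are taken from the diff above; the layer sizes and the rest of the setup are illustrative assumptions.

# Illustrative sketch, not part of this PR's diff: converting an nn.Linear to
# Float8Linear with per-tensor scaling choices via the new from_float kwargs.
# Layer sizes are arbitrary; running actual fp8 matmuls needs capable hardware.
import torch

from float8_experimental.float8_linear import Float8Linear, TensorScalingType

m = torch.nn.Linear(32, 16, bias=True)

# Delayed scaling for x and w, dynamic scaling for dL_dY.
m_fp8 = Float8Linear.from_float(
    m,
    emulate=False,
    scaling_type_x=TensorScalingType.DELAYED,
    scaling_type_w=TensorScalingType.DELAYED,
    scaling_type_dL_dY=TensorScalingType.DYNAMIC,
)

# extra_repr() now records the scaling configuration, e.g.:
#   in_features=32, out_features=16, bias=True, scaling="x:del,w:del,dldy:dyn"
print(m_fp8)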