This repository was archived by the owner on Aug 7, 2024. It is now read-only.

[POC] Added fp8 all-gather extensions #216

Closed
wants to merge 6 commits
57 changes: 52 additions & 5 deletions float8_experimental/float8_dynamic_linear.py
@@ -6,7 +6,11 @@
"""
A wrapper around a `torch.nn.Linear` module which does fp8 compute.
"""

from typing import Any, cast, Optional, Tuple, Union

import torch
import torch.nn as nn

from float8_experimental.float8_tensor import Float8Tensor, to_fp8_no_autograd
from float8_experimental.float8_utils import tensor_to_scale
@@ -73,7 +77,11 @@ def forward(self, x):
x_fp8 = x if self.use_activation_hooks else self.cast_to_float8_e4m3fn(x)

# cast w to float8_e4m3fn
w_fp8 = self.cast_to_float8_e4m3fn(self.weight)
w_fp8 = (
self.weight
if isinstance(self.weight, Float8Tensor)
else self.cast_to_float8_e4m3fn(self.weight)
)

y = torch.nn.functional.linear(x_fp8, w_fp8, self.bias)

@@ -83,8 +91,10 @@

return y

def cast_to_float8_e4m3fn(self, inpt_tensor: torch.Tensor) -> Float8Tensor:
scale = tensor_to_scale(inpt_tensor, torch.float8_e4m3fn)
def cast_to_float8_e4m3fn(
self, inpt_tensor: torch.Tensor, reduce_amax: bool = False
) -> Float8Tensor:
scale = tensor_to_scale(inpt_tensor, torch.float8_e4m3fn, reduce_amax)
return Float8Tensor.to_float8(
inpt_tensor, scale, torch.float8_e4m3fn, emulate=self.emulate
)
@@ -94,7 +104,11 @@ def cast_to_float8_e5m2_bw(self, gradY: torch.Tensor) -> torch.Tensor:

@classmethod
def from_float(
cls, mod, emulate: bool = False, use_activation_hooks: bool = False
cls,
mod: nn.Module,
emulate: bool = False,
use_activation_hooks: bool = False,
use_fp8_all_gather: bool = False,
) -> "Float8DynamicLinear":
"""
Create an nn.Linear with fp8 compute from a regular nn.Linear
@@ -111,7 +125,11 @@ def from_float(
"bias": False,
}
new_mod = cls(use_activation_hooks, **super_kwargs)
new_mod.weight = mod.weight
new_mod.weight = (
nn.Parameter(Float8DynamicLinearWeightTensor(mod.weight))
if use_fp8_all_gather
else mod.weight
)
new_mod.bias = mod.bias
new_mod.emulate = emulate
if new_mod.use_activation_hooks:
@@ -121,3 +139,32 @@
cast_grad_to_float8_e5m2_backward_forward_hook
)
return new_mod


class Float8DynamicLinearWeightTensor(torch.Tensor):
# TODO: Remove `module` arg, save state on subclass, and propagate it.
def fsdp_pre_all_gather(
self, module: nn.Module
) -> Tuple[Tuple[torch.Tensor, ...], Any]:
float8_tensor = module.cast_to_float8_e4m3fn(self, reduce_amax=True)
return (float8_tensor._data,), (float8_tensor._scale, module.emulate)

def fsdp_post_all_gather(
self,
all_gather_outputs: Tuple[torch.Tensor, ...],
metadata: Any,
param_dtype: torch.dtype,
*,
out: Optional[torch.Tensor] = None,
) -> Union[Tuple[Float8Tensor, Tuple[torch.Tensor, ...]], None]:
(data,) = all_gather_outputs
scale, emulate = metadata
if out is not None:
out = cast(Float8Tensor, out)
assert (
data.untyped_storage().data_ptr()
== out._data.untyped_storage().data_ptr()
)
out._scale = scale
return
return Float8Tensor(data, scale, param_dtype, emulate), (data,)
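
For readers unfamiliar with the per-parameter FSDP extension point targeted here, the sketch below (not part of this PR; the helper name drive_fp8_all_gather and the world_size == 1 shortcut are illustrative only) walks the contract that Float8DynamicLinearWeightTensor implements: fsdp_pre_all_gather hands back the fp8 bytes plus (scale, emulate) metadata, and fsdp_post_all_gather rebuilds a Float8Tensor from the gathered bytes. With reduce_amax=True the real path also expects an initialized process group so the amax can be all-reduced before the scale is computed.

import torch

from float8_experimental.float8_tensor import Float8Tensor

def drive_fp8_all_gather(linear) -> Float8Tensor:
    # `linear` is assumed to be a Float8DynamicLinear built via
    # from_float(..., use_fp8_all_gather=True), so its weight is a
    # Float8DynamicLinearWeightTensor.
    weight = linear.weight
    # 1) Pre-all-gather: the subclass returns the fp8 data to gather and
    #    (scale, emulate) as metadata.
    (fp8_data,), metadata = weight.fsdp_pre_all_gather(linear)
    # 2) Stand-in for the collective: with world_size == 1 the local shard is
    #    already the full tensor, so it is reused directly.
    gathered = fp8_data
    # 3) Post-all-gather: the subclass wraps the gathered bytes and the shared
    #    scale back into a Float8Tensor for the fp8 matmul.
    w_fp8, _to_free = weight.fsdp_post_all_gather((gathered,), metadata, torch.bfloat16)
    return w_fp8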
53 changes: 48 additions & 5 deletions float8_experimental/float8_linear.py
@@ -14,14 +14,14 @@

import dataclasses

from typing import Optional
from typing import Any, cast, Optional, Tuple, Union

import float8_experimental.config as config

import torch
import torch.nn as nn

from float8_experimental.float8_tensor import Float8Tensor

from float8_experimental.float8_utils import (
amax_history_to_scale,
E4M3_MAX_POS,
@@ -294,7 +294,11 @@ def forward(self, x):
self.float8_pre_forward(x)

x_fp8 = self.cast_x_to_float8(x, self.is_amax_initialized)
w_fp8 = self.cast_w_to_float8(self.weight, self.is_amax_initialized)
w_fp8 = (
self.weight
if isinstance(self.weight, Float8Tensor)
else self.cast_w_to_float8(self.weight, self.is_amax_initialized)
)

y = torch.matmul(x_fp8, w_fp8.t())

@@ -308,7 +312,13 @@
return y

@classmethod
def from_float(cls, mod, emulate: bool = False, use_activation_hooks: bool = False):
def from_float(
cls,
mod: nn.Module,
emulate: bool = False,
use_activation_hooks: bool = False,
use_fp8_all_gather: bool = False,
):
"""
Create an nn.Linear with fp8 compute from a regular nn.Linear

@@ -322,9 +332,42 @@ def from_float(cls, mod, emulate: bool = False, use_activation_hooks: bool = Fal
# Tensors and the Linear base to create empty params
# with torch.device("meta"):
new_mod = cls(mod.in_features, mod.out_features, bias=False)
new_mod.weight = mod.weight
new_mod.weight = (
nn.Parameter(Float8LinearWeightTensor(mod.weight))
if use_fp8_all_gather
else mod.weight
)
new_mod.bias = mod.bias
new_mod.emulate = emulate
# I think its okay to send all params and buffers to device
new_mod.to(mod.weight.device)
return new_mod


class Float8LinearWeightTensor(torch.Tensor):
# TODO: Remove `module` arg, save state on subclass, and propagate it.
def fsdp_pre_all_gather(
self, module: nn.Module
) -> Tuple[Tuple[torch.Tensor, ...], Any]:
float8_tensor = module.cast_w_to_float8(self, module.is_amax_initialized)
return (float8_tensor._data,), (float8_tensor._scale, module.emulate)

def fsdp_post_all_gather(
self,
all_gather_outputs: Tuple[torch.Tensor, ...],
metadata: Any,
param_dtype: torch.dtype,
*,
out: Optional[torch.Tensor] = None,
) -> Union[Tuple[Float8Tensor, Tuple[torch.Tensor, ...]], None]:
(data,) = all_gather_outputs
scale, emulate = metadata
if out is not None:
out = cast(Float8Tensor, out)
assert (
data.untyped_storage().data_ptr()
== out._data.untyped_storage().data_ptr()
)
out._scale = scale
return
return Float8Tensor(data, scale, param_dtype, emulate), (data,)
12 changes: 10 additions & 2 deletions float8_experimental/float8_linear_utils.py
@@ -93,6 +93,7 @@ def swap_linear_with_float8_linear(
skip_fqn_list: Optional[List[str]] = None,
emulate: bool = False,
use_activation_hooks: bool = False,
use_fp8_all_gather: bool = False,
) -> nn.Module:
"""
Replaces all instances of ``torch.nn.Linear`` in ``module`` with instances
Expand All @@ -105,6 +106,7 @@ def swap_linear_with_float8_linear(
Linear submodules of these skipped modules will also be skipped.
emulate (bool): Whether to emulate the fp8 matmul logic in fp32.
use_activation_hooks (bool): Whether to cast activations to fp8 using module hooks.
use_fp8_all_gather (bool): Whether to cast to fp8 before all-gather when using FSDP.
"""
module_names_to_skip = set(skip_fqn_list or [])
if isinstance(module, nn.Linear):
@@ -113,7 +115,10 @@
f"Does not support a root nn.Linear with children: {module}"
)
return module_cls.from_float(
module, emulate=emulate, use_activation_hooks=use_activation_hooks
module,
emulate=emulate,
use_activation_hooks=use_activation_hooks,
use_fp8_all_gather=use_fp8_all_gather,
)

# Mark all modules to skip as visited
@@ -137,7 +142,10 @@ def post_order_traversal(
parent_module is not None
), f"Linear root module should return early: {module}"
float8linear_module = module_cls.from_float(
module, emulate=emulate, use_activation_hooks=use_activation_hooks
module,
emulate=emulate,
use_activation_hooks=use_activation_hooks,
use_fp8_all_gather=use_fp8_all_gather,
)
setattr(parent_module, module_name, float8linear_module)

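A usage sketch for the new flag (the exact call shape is an assumption based on this diff; emulate=True only keeps the sketch independent of fp8-capable hardware):

import torch.nn as nn

from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear

model = nn.Sequential(nn.Linear(4096, 4096), nn.Linear(4096, 4096))
model = swap_linear_with_float8_linear(
    model,
    Float8DynamicLinear,
    emulate=True,             # keep the matmul emulated in fp32 for this sketch
    use_fp8_all_gather=True,  # wrap each weight in Float8DynamicLinearWeightTensor
)
# Each swapped linear now carries a weight subclass exposing
# fsdp_pre_all_gather/fsdp_post_all_gather, so an FSDP implementation that
# recognizes the extension point can all-gather fp8 data plus a scale instead
# of the high-precision weight.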
4 changes: 2 additions & 2 deletions float8_experimental/float8_utils.py
@@ -87,8 +87,8 @@ def tensor_to_amax(x, distributed_reduction=False):


@torch.no_grad()
def tensor_to_scale(x, float8_dtype):
amax = tensor_to_amax(x)
def tensor_to_scale(x, float8_dtype: torch.dtype, distributed_reduction: bool = False):
amax = tensor_to_amax(x, distributed_reduction=distributed_reduction)
return amax_to_scale(amax, float8_dtype, x.dtype)


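A quick single-process illustration of the widened tensor_to_scale signature; distributed_reduction stays False here, since True is meant for the FSDP pre-all-gather path, where the amax is all-reduced across ranks so every rank derives the same scale (the last line only approximates what Float8Tensor.to_float8 does with the scale):

import torch

from float8_experimental.float8_utils import tensor_to_scale

w = torch.randn(512, 512)
scale = tensor_to_scale(w, torch.float8_e4m3fn, False)  # equivalent to the old two-argument call
w_fp8 = (w * scale).to(torch.float8_e4m3fn)             # scaled cast, roughly what to_float8 performs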
2 changes: 1 addition & 1 deletion test/test_everything.sh
@@ -9,6 +9,6 @@ pytest test/test_compile.py
./test/test_fsdp.sh
./test/test_fsdp_compile.sh
./test/test_tp.sh
pytest test/test_fsdp/test_flat_param_fsdp_compile.py
pytest test/test_fsdp/*

echo "all tests successful"