This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Removed module arg from fsdp_pre_all_gather #217

Closed · wants to merge 7 commits
float8_experimental/float8_dynamic_linear.py (55 changes: 41 additions, 14 deletions)
@@ -7,14 +7,17 @@
A wrapper around a `torch.nn.Linear` module which does fp8 compute.
"""

from typing import Any, cast, Optional, Tuple, Union
import functools
from typing import Any, Callable, cast, Optional, Tuple, Union

import torch
import torch.nn as nn

from float8_experimental.float8_tensor import Float8Tensor, to_fp8_no_autograd
from float8_experimental.float8_utils import tensor_to_scale

from torch.utils._pytree import tree_map


@torch._dynamo.allow_in_graph
class NoopFwToFloat8E5M2Bw(torch.autograd.Function):
@@ -125,11 +128,13 @@ def from_float(
"bias": False,
}
new_mod = cls(use_activation_hooks, **super_kwargs)
new_mod.weight = (
nn.Parameter(Float8DynamicLinearWeightTensor(mod.weight))
if use_fp8_all_gather
else mod.weight
)
if use_fp8_all_gather:
cast_fn = new_mod.cast_to_float8_e4m3fn
new_mod.weight = nn.Parameter(
Float8DynamicLinearWeightTensor(mod.weight, cast_fn, emulate)
)
else:
new_mod.weight = mod.weight
new_mod.bias = mod.bias
new_mod.emulate = emulate
if new_mod.use_activation_hooks:
@@ -142,12 +147,34 @@ def from_float(


class Float8DynamicLinearWeightTensor(torch.Tensor):
# TODO: Remove `module` arg, save state on subclass, and propagate it.
def fsdp_pre_all_gather(
self, module: nn.Module
Author comment (on the removed module parameter):
The problem is that, in order to implement the pre-all-gather transform, the subclass needs some state (mainly emulate, but preferably the cast function as well). In the first prototype, I took a shortcut and passed module into fsdp_pre_all_gather() so that the transform could read state off the module instead of storing it on the subclass itself.

However, the morally right thing, from a design perspective, is to put that state on the subclass. I wonder, though, how this kind of __torch_function__ subclass interacts with torch.compile today. cc: @bdhirsh (A standalone sketch of this pattern follows this file's diff.)

) -> Tuple[Tuple[torch.Tensor, ...], Any]:
float8_tensor = module.cast_to_float8_e4m3fn(self, reduce_amax=True)
return (float8_tensor._data,), (float8_tensor._scale, module.emulate)
def __new__(cls, tensor: torch.Tensor, cast_fn: Callable, emulate: bool):
return cls._make_subclass(cls, tensor, tensor.requires_grad)

def __init__(self, tensor: torch.Tensor, cast_fn: Callable, emulate: bool):
super().__init__()
self.cast_fn = cast_fn
self.emulate = emulate

@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
kwargs = kwargs or {}

def wrap(cast_fn: Callable, emulate: bool, o: Any):
if isinstance(o, torch.Tensor) and not isinstance(o, cls):
return cls(o, cast_fn, emulate)
return o

with torch._C.DisableTorchFunctionSubclass():
if isinstance(args[0], cls):
out = func(*args, **kwargs)
return tree_map(
functools.partial(wrap, args[0].cast_fn, args[0].emulate), out
)
return func(*args, **kwargs)

def fsdp_pre_all_gather(self) -> Tuple[Tuple[torch.Tensor, ...], Any]:
float8_tensor = self.cast_fn(self, reduce_amax=True)
return (float8_tensor._data,), (float8_tensor._scale,)

def fsdp_post_all_gather(
self,
@@ -158,7 +185,7 @@ def fsdp_post_all_gather(
out: Optional[torch.Tensor] = None,
) -> Union[Tuple[Float8Tensor, Tuple[torch.Tensor, ...]], None]:
(data,) = all_gather_outputs
scale, emulate = metadata
(scale,) = metadata
if out is not None:
out = cast(Float8Tensor, out)
assert (
@@ -167,4 +194,4 @@
)
out._scale = scale
return
return Float8Tensor(data, scale, param_dtype, emulate), (data,)
return Float8Tensor(data, scale, param_dtype, self.emulate), (data,)
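
Note (not part of the PR): to make the design point in the author comment above concrete, here is a minimal standalone sketch of the pattern this file now uses: a __torch_function__ wrapper subclass that carries its own state, so fsdp_pre_all_gather no longer needs a module argument. The StatefulWeight class and toy_cast function are made-up names for illustration; only the overall shape mirrors the diff.

import functools
from typing import Any, Callable, Tuple

import torch
from torch.utils._pytree import tree_map


class StatefulWeight(torch.Tensor):
    # Hypothetical stand-in for Float8DynamicLinearWeightTensor.
    def __new__(cls, tensor: torch.Tensor, cast_fn: Callable, emulate: bool):
        return cls._make_subclass(cls, tensor, tensor.requires_grad)

    def __init__(self, tensor: torch.Tensor, cast_fn: Callable, emulate: bool):
        super().__init__()
        # The state the pre-all-gather transform needs lives on the subclass,
        # not on the owning module.
        self.cast_fn = cast_fn
        self.emulate = emulate

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}

        def wrap(cast_fn: Callable, emulate: bool, o: Any):
            # Re-wrap plain tensor outputs so the state keeps propagating.
            if isinstance(o, torch.Tensor) and not isinstance(o, cls):
                return cls(o, cast_fn, emulate)
            return o

        with torch._C.DisableTorchFunctionSubclass():
            if isinstance(args[0], cls):
                out = func(*args, **kwargs)
                return tree_map(
                    functools.partial(wrap, args[0].cast_fn, args[0].emulate), out
                )
            return func(*args, **kwargs)

    def fsdp_pre_all_gather(self) -> Tuple[Tuple[torch.Tensor, ...], Any]:
        # No module argument: everything the hook needs is already on self.
        lowered = self.cast_fn(self)
        return (lowered,), (self.emulate,)


def toy_cast(t: torch.Tensor) -> torch.Tensor:
    # Stand-in for the real float8 cast, which would return a Float8Tensor.
    return t.to(torch.bfloat16)


weight = StatefulWeight(torch.randn(8, 8), cast_fn=toy_cast, emulate=False)
(sharded,), metadata = weight.fsdp_pre_all_gather()
# metadata now carries only subclass state; no module is needed anywhere.
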
float8_experimental/float8_linear.py (54 changes: 39 additions, 15 deletions)
@@ -13,8 +13,9 @@
"""

import dataclasses
import functools

from typing import Any, cast, Optional, Tuple, Union
from typing import Any, Callable, cast, Optional, Tuple, Union

import float8_experimental.config as config

@@ -29,6 +30,7 @@
tensor_to_amax,
to_fp8_saturated,
)
from torch.utils._pytree import tree_map


def _maybe_initialize_amaxes_scales_for_float8_cast(
@@ -222,9 +224,7 @@ def cast_x_to_float8(
)
return x_fp8

def cast_w_to_float8(
self, w: torch.Tensor, is_amax_initialized: bool
) -> torch.Tensor:
def cast_w_to_float8(self, w: torch.Tensor) -> torch.Tensor:
scale_fn_name = self.recipe.scale_fn_name
_maybe_initialize_amaxes_scales_for_float8_cast(
w,
@@ -233,7 +233,7 @@ def cast_w_to_float8(
self.fp8_scale_w,
scale_fn_name,
torch.float8_e4m3fn,
is_amax_initialized,
self.is_amax_initialized,
Author comment (on the change to self.is_amax_initialized):
We have to push this is_amax_initialized read into the cast function, since the Float8LinearWeightTensor subclass cannot keep a live reference to the bool. (A short illustration of why follows this hunk.)

)

w_fp8 = Float8Tensor.to_float8(
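
Note (not part of the PR): a short illustration of the constraint the comment above describes. The Toy class is a hypothetical stand-in for the Float8Linear module; the point is that a Python bool is captured by value, so a flag copied onto the weight subclass at construction would go stale, whereas reading self.is_amax_initialized inside the cast always sees the live value.

class Toy:
    # Hypothetical stand-in for the module that owns is_amax_initialized.
    def __init__(self):
        self.is_amax_initialized = False

m = Toy()
captured = m.is_amax_initialized   # copies the current value (False), not a reference
m.is_amax_initialized = True       # the module flips the flag after the first cast
print(captured)                    # False -- a copy stored on the subclass would be stale
print(m.is_amax_initialized)       # True  -- reading through the module stays current
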
@@ -297,7 +297,7 @@ def forward(self, x):
w_fp8 = (
self.weight
if isinstance(self.weight, Float8Tensor)
else self.cast_w_to_float8(self.weight, self.is_amax_initialized)
else self.cast_w_to_float8(self.weight)
)

y = torch.matmul(x_fp8, w_fp8.t())
@@ -333,7 +333,9 @@ def from_float(
# with torch.device("meta"):
new_mod = cls(mod.in_features, mod.out_features, bias=False)
new_mod.weight = (
nn.Parameter(Float8LinearWeightTensor(mod.weight))
nn.Parameter(
Float8LinearWeightTensor(mod.weight, new_mod.cast_w_to_float8, emulate)
)
if use_fp8_all_gather
else mod.weight
)
@@ -345,12 +347,34 @@


class Float8LinearWeightTensor(torch.Tensor):
# TODO: Remove `module` arg, save state on subclass, and propagate it.
def fsdp_pre_all_gather(
self, module: nn.Module
) -> Tuple[Tuple[torch.Tensor, ...], Any]:
float8_tensor = module.cast_w_to_float8(self, module.is_amax_initialized)
return (float8_tensor._data,), (float8_tensor._scale, module.emulate)
def __new__(cls, tensor: torch.Tensor, cast_fn: Callable, emulate: bool):
return cls._make_subclass(cls, tensor, tensor.requires_grad)

def __init__(self, tensor: torch.Tensor, cast_fn: Callable, emulate: bool):
super().__init__()
self.cast_fn = cast_fn
self.emulate = emulate

@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
kwargs = kwargs or {}

def wrap(cast_fn: Callable, emulate: bool, o: Any):
if isinstance(o, torch.Tensor) and not isinstance(o, cls):
return cls(o, cast_fn, emulate)
return o

with torch._C.DisableTorchFunctionSubclass():
if isinstance(args[0], cls):
Author comment (on the isinstance(args[0], cls) check):
The heuristic here is simply to propagate the Float8LinearWeightTensor type as long as it is the first argument. (A hypothetical usage sketch follows at the end of the diff.)

out = func(*args, **kwargs)
return tree_map(
functools.partial(wrap, args[0].cast_fn, args[0].emulate), out
)
return func(*args, **kwargs)

def fsdp_pre_all_gather(self) -> Tuple[Tuple[torch.Tensor, ...], Any]:
float8_tensor = self.cast_fn(self)
return (float8_tensor._data,), (float8_tensor._scale,)

def fsdp_post_all_gather(
self,
@@ -361,7 +385,7 @@ def fsdp_post_all_gather(
out: Optional[torch.Tensor] = None,
) -> Union[Tuple[Float8Tensor, Tuple[torch.Tensor, ...]], None]:
(data,) = all_gather_outputs
scale, emulate = metadata
(scale,) = metadata
if out is not None:
out = cast(Float8Tensor, out)
assert (
@@ -370,4 +394,4 @@
)
out._scale = scale
return
return Float8Tensor(data, scale, param_dtype, emulate), (data,)
return Float8Tensor(data, scale, param_dtype, self.emulate), (data,)
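
Note (not part of the PR): a hypothetical usage sketch of the first-argument heuristic called out in the comment above. The toy_cast helper is made up, and the import assumes this branch's float8_experimental. Because an op led by the subclass has its output re-wrapped, cast_fn and emulate travel along with views and copies of the weight, which is what lets the module-free fsdp_pre_all_gather find them later.

import torch

from float8_experimental.float8_linear import Float8LinearWeightTensor


def toy_cast(t: torch.Tensor) -> torch.Tensor:
    # Stand-in for Float8Linear.cast_w_to_float8, which would return a Float8Tensor.
    return t


w = Float8LinearWeightTensor(torch.randn(16, 16), cast_fn=toy_cast, emulate=False)
v = w.t()  # args[0] is the subclass, so the output is re-wrapped with the same state
assert isinstance(v, Float8LinearWeightTensor)
assert v.cast_fn is toy_cast and v.emulate is False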