Removed `module` arg from `fsdp_pre_all_gather` #217 (Closed)
Changes from all commits (7):

- fcd7115: Removed `module` arg from `fsdp_pre_all_gather`
- 839bcb3: Update on "Removed `module` arg from `fsdp_pre_all_gather`"
- d839417: Update on "Removed `module` arg from `fsdp_pre_all_gather`"
- 6efa0bc: Update on "Removed `module` arg from `fsdp_pre_all_gather`"
- 7ef9fe2: Update on "Removed `module` arg from `fsdp_pre_all_gather`"
- bbfb405: Update on "Removed `module` arg from `fsdp_pre_all_gather`"
- c10d100: Update on "Removed `module` arg from `fsdp_pre_all_gather`"
Diff:

@@ -13,8 +13,9 @@
 """
 
 import dataclasses
+import functools
 
-from typing import Any, cast, Optional, Tuple, Union
+from typing import Any, Callable, cast, Optional, Tuple, Union
 
 import float8_experimental.config as config
 
@@ -29,6 +30,7 @@
     tensor_to_amax,
     to_fp8_saturated,
 )
+from torch.utils._pytree import tree_map
 
 
 def _maybe_initialize_amaxes_scales_for_float8_cast(
@@ -222,9 +224,7 @@ def cast_x_to_float8(
         )
         return x_fp8
 
-    def cast_w_to_float8(
-        self, w: torch.Tensor, is_amax_initialized: bool
-    ) -> torch.Tensor:
+    def cast_w_to_float8(self, w: torch.Tensor) -> torch.Tensor:
         scale_fn_name = self.recipe.scale_fn_name
         _maybe_initialize_amaxes_scales_for_float8_cast(
             w,
@@ -233,7 +233,7 @@ def cast_w_to_float8(
             self.fp8_scale_w,
             scale_fn_name,
             torch.float8_e4m3fn,
-            is_amax_initialized,
+            self.is_amax_initialized,
         )
 
         w_fp8 = Float8Tensor.to_float8(
@@ -297,7 +297,7 @@ def forward(self, x):
         w_fp8 = (
             self.weight
             if isinstance(self.weight, Float8Tensor)
-            else self.cast_w_to_float8(self.weight, self.is_amax_initialized)
+            else self.cast_w_to_float8(self.weight)
         )
 
         y = torch.matmul(x_fp8, w_fp8.t())
@@ -333,7 +333,9 @@ def from_float(
         # with torch.device("meta"):
         new_mod = cls(mod.in_features, mod.out_features, bias=False)
         new_mod.weight = (
-            nn.Parameter(Float8LinearWeightTensor(mod.weight))
+            nn.Parameter(
+                Float8LinearWeightTensor(mod.weight, new_mod.cast_w_to_float8, emulate)
+            )
             if use_fp8_all_gather
             else mod.weight
         )
@@ -345,12 +347,34 @@ def from_float(
 
 
 class Float8LinearWeightTensor(torch.Tensor):
-    # TODO: Remove `module` arg, save state on subclass, and propagate it.
-    def fsdp_pre_all_gather(
-        self, module: nn.Module
-    ) -> Tuple[Tuple[torch.Tensor, ...], Any]:
-        float8_tensor = module.cast_w_to_float8(self, module.is_amax_initialized)
-        return (float8_tensor._data,), (float8_tensor._scale, module.emulate)
+    def __new__(cls, tensor: torch.Tensor, cast_fn: Callable, emulate: bool):
+        return cls._make_subclass(cls, tensor, tensor.requires_grad)
+
+    def __init__(self, tensor: torch.Tensor, cast_fn: Callable, emulate: bool):
+        super().__init__()
+        self.cast_fn = cast_fn
+        self.emulate = emulate
+
+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None):
+        kwargs = kwargs or {}
+
+        def wrap(cast_fn: Callable, emulate: bool, o: Any):
+            if isinstance(o, torch.Tensor) and not isinstance(o, cls):
+                return cls(o, cast_fn, emulate)
+            return o
+
+        with torch._C.DisableTorchFunctionSubclass():
+            if isinstance(args[0], cls):
+                out = func(*args, **kwargs)
+                return tree_map(
+                    functools.partial(wrap, args[0].cast_fn, args[0].emulate), out
+                )
+            return func(*args, **kwargs)
+
+    def fsdp_pre_all_gather(self) -> Tuple[Tuple[torch.Tensor, ...], Any]:
+        float8_tensor = self.cast_fn(self)
+        return (float8_tensor._data,), (float8_tensor._scale,)
 
     def fsdp_post_all_gather(
         self,
@@ -361,7 +385,7 @@ def fsdp_post_all_gather(
         out: Optional[torch.Tensor] = None,
     ) -> Union[Tuple[Float8Tensor, Tuple[torch.Tensor, ...]], None]:
         (data,) = all_gather_outputs
-        scale, emulate = metadata
+        (scale,) = metadata
         if out is not None:
             out = cast(Float8Tensor, out)
             assert (
@@ -370,4 +394,4 @@ def fsdp_post_all_gather(
             )
             out._scale = scale
             return
-        return Float8Tensor(data, scale, param_dtype, emulate), (data,)
+        return Float8Tensor(data, scale, param_dtype, self.emulate), (data,)

Review comments on the diff:

- On the change to `self.is_amax_initialized`: "We have to push this ..."
- On the `isinstance(args[0], cls)` check in `__torch_function__`: "The heuristic here is just to propagate the ..."
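The `__torch_function__` override added in the diff follows the usual metadata-propagation pattern for tensor subclasses: run the op with subclass dispatch disabled, then re-wrap any plain tensor outputs so they keep carrying `cast_fn` and `emulate`. The following is a minimal, self-contained sketch of that pattern; `StatefulTensor` and its single `tag` field are illustrative stand-ins (playing the role of the PR's `cast_fn`/`emulate` state), not code from this repository.

```python
import functools
from typing import Any

import torch
from torch.utils._pytree import tree_map


class StatefulTensor(torch.Tensor):
    """Toy __torch_function__ subclass that carries extra state through ops."""

    def __new__(cls, tensor: torch.Tensor, tag: str):
        return cls._make_subclass(cls, tensor, tensor.requires_grad)

    def __init__(self, tensor: torch.Tensor, tag: str):
        super().__init__()
        self.tag = tag  # stand-in for the PR's cast_fn / emulate state

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}

        def wrap(tag: str, o: Any):
            # Re-wrap plain tensor outputs so they keep carrying the state.
            if isinstance(o, torch.Tensor) and not isinstance(o, cls):
                return cls(o, tag)
            return o

        with torch._C.DisableTorchFunctionSubclass():
            if args and isinstance(args[0], cls):
                out = func(*args, **kwargs)
                return tree_map(functools.partial(wrap, args[0].tag), out)
            return func(*args, **kwargs)


if __name__ == "__main__":
    w = StatefulTensor(torch.randn(4, 4), tag="fp8-weight")
    y = (w * 2).t()
    print(type(y).__name__, y.tag)  # StatefulTensor fp8-weight
```

With subclass dispatch disabled inside the `with` block, the underlying op returns plain tensors, which is why the `wrap` helper has to restore the subclass type and its attributes on the way out.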
From the PR discussion: The problem statement is that in order to implement the pre-all-gather transform, the subclass needs some state (mainly `emulate`, but preferably the cast function as well). In the first prototype, I took a shortcut by passing `module` into `fsdp_pre_all_gather()` so that the transform could read state off the module instead of storing it on the subclass itself. However, the morally right thing (from a design perspective) is to put that state on the subclass. I wonder, though, how this kind of `__torch_function__` subclass interacts with `torch.compile` today. cc: @bdhirsh
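To make the before/after contract concrete, here is a hedged, self-contained sketch of the new hook shape with the state stored on the subclass. `ToyFloat8`, `toy_cast`, and `ToyWeight` are hypothetical stand-ins for `Float8Tensor`, `cast_w_to_float8`, and `Float8LinearWeightTensor`, and the hooks are invoked by hand rather than by FSDP; the point is only to show that the scale still travels in the all-gather metadata while `emulate` now comes from `self`.

```python
from typing import Any, Callable, Optional, Tuple

import torch


class ToyFloat8:
    """Hypothetical stand-in for Float8Tensor: a payload plus its scale."""

    def __init__(self, data: torch.Tensor, scale: torch.Tensor, emulate: bool):
        self._data = data
        self._scale = scale
        self._emulate = emulate


def toy_cast(w: torch.Tensor) -> ToyFloat8:
    # Toy per-tensor scaling; the real cast also maintains amax history/scales.
    scale = w.abs().amax().clamp(min=1e-12)
    return ToyFloat8(w / scale, scale, emulate=False)


class ToyWeight(torch.Tensor):
    def __new__(cls, tensor: torch.Tensor, cast_fn: Callable, emulate: bool):
        return cls._make_subclass(cls, tensor, tensor.requires_grad)

    def __init__(self, tensor: torch.Tensor, cast_fn: Callable, emulate: bool):
        super().__init__()
        self.cast_fn = cast_fn   # state that used to be read off the module
        self.emulate = emulate

    # New-style hook: no `module` parameter; all state comes from `self`.
    def fsdp_pre_all_gather(self) -> Tuple[Tuple[torch.Tensor, ...], Any]:
        f8 = self.cast_fn(self)
        return (f8._data,), (f8._scale,)

    def fsdp_post_all_gather(
        self,
        all_gather_outputs: Tuple[torch.Tensor, ...],
        metadata: Any,
        param_dtype: torch.dtype,  # unused in this toy version
        *,
        out: Optional[torch.Tensor] = None,
    ):
        (data,) = all_gather_outputs
        (scale,) = metadata                 # `emulate` no longer rides in metadata
        return ToyFloat8(data, scale, self.emulate), (data,)


if __name__ == "__main__":
    w = ToyWeight(torch.randn(4, 4), toy_cast, emulate=True)
    inner, meta = w.fsdp_pre_all_gather()          # what would be all-gathered
    f8, _ = w.fsdp_post_all_gather(inner, meta, torch.bfloat16)
    print(f8._scale.item(), f8._emulate)
```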