This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Added use_activation_hooks: bool to swap #214

Closed · wants to merge 7 commits
16 changes: 12 additions & 4 deletions float8_experimental/float8_linear_utils.py
@@ -89,8 +89,10 @@ def _update_history_stack(
 def swap_linear_with_float8_linear(
     module: nn.Module,
     module_cls: Type[nn.Module],
-    emulate: bool = False,
+    *,
Author:
I was thinking that forcing these to be kwarg only would be less error prone.

Contributor:
yeah I like it!
+    skip_fqn_list: Optional[List[str]] = None,
+    emulate: bool = False,
+    use_activation_hooks: bool = False,
 ) -> nn.Module:
     """
     Replaces all instances of ``torch.nn.Linear`` in ``module`` with instances
@@ -99,17 +101,20 @@ def swap_linear_with_float8_linear(
     Args:
         module (torch.nn.Module): Module to modify.
         module_cls (Union[Type[Float8Linear], Type[Float8DynamicLinear]]): Float8 linear class for the swap.
-        emulate (bool, optional): Whether to emulate the fp8 matmul logic in fp32.
         skip_fqn_list (List[str], optional): If specified, a list of module FQNs to skip.
             Linear submodules of these skipped modules will also be skipped.
+        emulate (bool): Whether to emulate the fp8 matmul logic in fp32.
+        use_activation_hooks (bool): Whether to cast activations to fp8 using module hooks.
     """
     module_names_to_skip = set(skip_fqn_list or [])
     if isinstance(module, nn.Linear):
         if len(list(module.children())) > 0:
             raise AssertionError(
                 f"Does not support a root nn.Linear with children: {module}"
             )
-        return module_cls.from_float(module, emulate)
+        return module_cls.from_float(
+            module, emulate=emulate, use_activation_hooks=use_activation_hooks
+        )

     # Mark all modules to skip as visited
     root_module = module
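For intuition about `use_activation_hooks`, here is a rough sketch of the general idea of casting activations via module hooks. It is an illustration only, under stated assumptions: a plain forward pre-hook and a stand-in bfloat16 cast, not float8_experimental's actual Float8 casting or hook registration logic:

```python
import torch
import torch.nn as nn


def cast_inputs_pre_hook(module: nn.Module, args):
    # Cast floating-point inputs before the module's forward runs.
    # bfloat16 is a stand-in here; the library would cast to a float8 representation.
    return tuple(
        a.to(torch.bfloat16)
        if isinstance(a, torch.Tensor) and a.is_floating_point()
        else a
        for a in args
    )


lin = nn.Linear(3, 3).to(torch.bfloat16)
lin.register_forward_pre_hook(cast_inputs_pre_hook)
out = lin(torch.randn(2, 3))  # input is cast by the hook before the matmul
```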
@@ -131,7 +136,10 @@ def post_order_traversal(
             assert (
                 parent_module is not None
             ), f"Linear root module should return early: {module}"
-            setattr(parent_module, module_name, module_cls.from_float(module, emulate))
+            float8linear_module = module_cls.from_float(
+                module, emulate=emulate, use_activation_hooks=use_activation_hooks
+            )
+            setattr(parent_module, module_name, float8linear_module)

     post_order_traversal(root_module, "", None)
     # Without this explicit `del`, this set only gets deleted upon an explicit
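The traversal above follows a common recursive swap pattern: walk the module tree and replace `nn.Linear` children on their parent via `setattr`. A simplified standalone sketch, with a hypothetical `make_replacement` factory standing in for `module_cls.from_float`, looks like this:

```python
import torch.nn as nn


def swap_linears(module: nn.Module, make_replacement):
    # Recursively replace nn.Linear children on their parent via setattr.
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            setattr(module, name, make_replacement(child))
        else:
            swap_linears(child, make_replacement)
    return module


# Toy replacement: swap each linear for a fresh one of the same shape.
model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 4))
swap_linears(model, lambda lin: nn.Linear(lin.in_features, lin.out_features))
```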
8 changes: 4 additions & 4 deletions test/test_base.py
@@ -351,7 +351,7 @@ def test_swap_root_linear(self):
             [Float8Linear, Float8DynamicLinear], [True, False]
         ):
             module = nn.Linear(3, 3)
-            module = swap_linear_with_float8_linear(module, module_cls, emulate)
+            module = swap_linear_with_float8_linear(module, module_cls, emulate=emulate)
             self.assertIsInstance(module, module_cls)
             self.assertEqual(module.emulate, emulate)

@@ -365,7 +365,7 @@ def test_swap_root_linear_with_children_raises(self):
                 AssertionError,
                 "Does not support a root nn.Linear with children",
             ):
-                swap_linear_with_float8_linear(module, module_cls, emulate)
+                swap_linear_with_float8_linear(module, module_cls, emulate=emulate)

     def test_swap_submodule_linears(self):
         class MLP(nn.Module):
@@ -378,7 +378,7 @@ def __init__(self, dim: int):
             [Float8Linear, Float8DynamicLinear], [True, False]
         ):
             model = nn.Sequential(MLP(3), nn.Linear(3, 3), MLP(3))
-            model = swap_linear_with_float8_linear(model, module_cls, emulate)
+            model = swap_linear_with_float8_linear(model, module_cls, emulate=emulate)
             self.assertIsInstance(model[0].lin1, module_cls)
             self.assertIsInstance(model[0].lin2, module_cls)
             self.assertIsInstance(model[1], module_cls)
@@ -398,7 +398,7 @@ def __init__(self, dim: int):
             model = nn.Sequential(MLP(3), nn.Linear(3, 3), MLP(3))
             skip_fqn_list = ["2", "0.lin2"]
             model = swap_linear_with_float8_linear(
-                model, module_cls, emulate, skip_fqn_list
+                model, module_cls, emulate=emulate, skip_fqn_list=skip_fqn_list
             )
             self.assertIsInstance(model[0].lin1, module_cls)
             self.assertNotIsInstance(model[0].lin2, module_cls)
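For completeness, a hedged usage sketch of the updated call signature, assuming the import paths below match the repo layout at the time of this PR:

```python
import torch.nn as nn
from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))
model = swap_linear_with_float8_linear(
    model,
    Float8Linear,
    skip_fqn_list=["2"],         # leave the last linear untouched
    emulate=True,                # run the fp8 matmul logic in fp32
    use_activation_hooks=False,  # the new flag introduced in this PR
)
```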