Skip to content
This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit b9b37f8

Browse files
Andrew Gu authored and facebook-github-bot committed
Added use_activation_hooks: bool to swap (#214)
Summary: Pull Request resolved: #214 **Test Plan** ``` ./test/test_everything.sh ``` imported-using-ghimport Test Plan: Imported from OSS Reviewed By: drisspg Differential Revision: D53780333 Pulled By: awgu fbshipit-source-id: 91dd1aba3173f2a9784a3f460ad957ea1c21d116
1 parent 6f22688 commit b9b37f8

File tree

2 files changed

+16
-8
lines changed

2 files changed

+16
-8
lines changed

float8_experimental/float8_linear_utils.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,10 @@ def _update_history_stack(
8989
def swap_linear_with_float8_linear(
9090
module: nn.Module,
9191
module_cls: Type[nn.Module],
92-
emulate: bool = False,
92+
*,
9393
skip_fqn_list: Optional[List[str]] = None,
94+
emulate: bool = False,
95+
use_activation_hooks: bool = False,
9496
) -> nn.Module:
9597
"""
9698
Replaces all instances of ``torch.nn.Linear`` in ``module`` with instances
@@ -99,17 +101,20 @@ def swap_linear_with_float8_linear(
99101
Args:
100102
module (torch.nn.Module): Module to modify.
101103
module_cls (Union[Type[Float8Linear], Type[Float8DynamicLinear]]): Float8 linear class for the swap.
102-
emulate (bool, optional): Whether to emulate the fp8 matmul logic in fp32.
103104
skip_fqn_list (List[str], optional): If specified, a list of module FQNs to skip.
104105
Linear submodules of these skipped modules will also be skipped.
106+
emulate (bool): Whether to emulate the fp8 matmul logic in fp32.
107+
use_activation_hooks (bool): Whether to cast activations to fp8 using module hooks.
105108
"""
106109
module_names_to_skip = set(skip_fqn_list or [])
107110
if isinstance(module, nn.Linear):
108111
if len(list(module.children())) > 0:
109112
raise AssertionError(
110113
f"Does not support a root nn.Linear with children: {module}"
111114
)
112-
return module_cls.from_float(module, emulate)
115+
return module_cls.from_float(
116+
module, emulate=emulate, use_activation_hooks=use_activation_hooks
117+
)
113118

114119
# Mark all modules to skip as visited
115120
root_module = module
@@ -131,7 +136,10 @@ def post_order_traversal(
131136
assert (
132137
parent_module is not None
133138
), f"Linear root module should return early: {module}"
134-
setattr(parent_module, module_name, module_cls.from_float(module, emulate))
139+
float8linear_module = module_cls.from_float(
140+
module, emulate=emulate, use_activation_hooks=use_activation_hooks
141+
)
142+
setattr(parent_module, module_name, float8linear_module)
135143

136144
post_order_traversal(root_module, "", None)
137145
# Without this explicit `del`, this set only gets deleted upon an explicit

test/test_base.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -351,7 +351,7 @@ def test_swap_root_linear(self):
351351
[Float8Linear, Float8DynamicLinear], [True, False]
352352
):
353353
module = nn.Linear(3, 3)
354-
module = swap_linear_with_float8_linear(module, module_cls, emulate)
354+
module = swap_linear_with_float8_linear(module, module_cls, emulate=emulate)
355355
self.assertIsInstance(module, module_cls)
356356
self.assertEqual(module.emulate, emulate)
357357

@@ -365,7 +365,7 @@ def test_swap_root_linear_with_children_raises(self):
365365
AssertionError,
366366
"Does not support a root nn.Linear with children",
367367
):
368-
swap_linear_with_float8_linear(module, module_cls, emulate)
368+
swap_linear_with_float8_linear(module, module_cls, emulate=emulate)
369369

370370
def test_swap_submodule_linears(self):
371371
class MLP(nn.Module):
@@ -378,7 +378,7 @@ def __init__(self, dim: int):
378378
[Float8Linear, Float8DynamicLinear], [True, False]
379379
):
380380
model = nn.Sequential(MLP(3), nn.Linear(3, 3), MLP(3))
381-
model = swap_linear_with_float8_linear(model, module_cls, emulate)
381+
model = swap_linear_with_float8_linear(model, module_cls, emulate=emulate)
382382
self.assertIsInstance(model[0].lin1, module_cls)
383383
self.assertIsInstance(model[0].lin2, module_cls)
384384
self.assertIsInstance(model[1], module_cls)
@@ -398,7 +398,7 @@ def __init__(self, dim: int):
398398
model = nn.Sequential(MLP(3), nn.Linear(3, 3), MLP(3))
399399
skip_fqn_list = ["2", "0.lin2"]
400400
model = swap_linear_with_float8_linear(
401-
model, module_cls, emulate, skip_fqn_list
401+
model, module_cls, emulate=emulate, skip_fqn_list=skip_fqn_list
402402
)
403403
self.assertIsInstance(model[0].lin1, module_cls)
404404
self.assertNotIsInstance(model[0].lin2, module_cls)

0 commit comments

Comments (0)