
Commit 34e30d1

Author: Andrew Gu
Added use_activation_hooks: bool to swap
ghstack-source-id: 65dd688 Pull Request resolved: #214
1 parent 0af8433 commit 34e30d1

3 files changed: +18 −8 lines

float8_experimental/float8_linear_utils.py

Lines changed: 12 additions & 4 deletions
@@ -70,8 +70,10 @@ def _update_history_with_new_amax(new_amax, amax_history):
 def swap_linear_with_float8_linear(
     module: nn.Module,
     module_cls: Type[nn.Module],
-    emulate: bool = False,
+    *,
     skip_fqn_list: Optional[List[str]] = None,
+    emulate: bool = False,
+    use_activation_hooks: bool = False,
 ) -> nn.Module:
     """
     Replaces all instances of ``torch.nn.Linear`` in ``module`` with instances
@@ -80,17 +82,20 @@ def swap_linear_with_float8_linear(
     Args:
         module (torch.nn.Module): Module to modify.
         module_cls (Union[Type[Float8Linear], Type[Float8DynamicLinear]]): Float8 linear class for the swap.
-        emulate (bool, optional): Whether to emulate the fp8 matmul logic in fp32.
         skip_fqn_list (List[str], optional): If specified, a list of module FQNs to skip.
             Linear submodules of these skipped modules will also be skipped.
+        emulate (bool): Whether to emulate the fp8 matmul logic in fp32.
+        use_activation_hooks (bool): Whether to cast activations to fp8 using module hooks.
     """
     module_names_to_skip = set(skip_fqn_list or [])
     if isinstance(module, nn.Linear):
         if len(list(module.children())) > 0:
             raise AssertionError(
                 f"Does not support a root nn.Linear with children: {module}"
             )
-        return module_cls.from_float(module, emulate)
+        return module_cls.from_float(
+            module, emulate=emulate, use_activation_hooks=use_activation_hooks
+        )

     # Mark all modules to skip as visited
     root_module = module
@@ -112,7 +117,10 @@ def post_order_traversal(
             assert (
                 parent_module is not None
             ), f"Linear root module should return early: {module}"
-            setattr(parent_module, module_name, module_cls.from_float(module, emulate))
+            float8linear_module = module_cls.from_float(
+                module, emulate=emulate, use_activation_hooks=use_activation_hooks
+            )
+            setattr(parent_module, module_name, float8linear_module)

     post_order_traversal(root_module, "", None)
     # Without this explicit `del`, this set only gets deleted upon an explicit

float8_experimental/float8_python_api.py

Lines changed: 2 additions & 0 deletions
@@ -12,6 +12,8 @@

 from typing import Optional, Tuple

+import float8_experimental.float8_aten_api  # noqa
+
 import torch
 from float8_experimental.float8_tensor import Float8Tensor

test/test_base.py

Lines changed: 4 additions & 4 deletions
@@ -351,7 +351,7 @@ def test_swap_root_linear(self):
             [Float8Linear, Float8DynamicLinear], [True, False]
         ):
             module = nn.Linear(3, 3)
-            module = swap_linear_with_float8_linear(module, module_cls, emulate)
+            module = swap_linear_with_float8_linear(module, module_cls, emulate=emulate)
             self.assertIsInstance(module, module_cls)
             self.assertEqual(module.emulate, emulate)

@@ -365,7 +365,7 @@ def test_swap_root_linear_with_children_raises(self):
             AssertionError,
             "Does not support a root nn.Linear with children",
         ):
-            swap_linear_with_float8_linear(module, module_cls, emulate)
+            swap_linear_with_float8_linear(module, module_cls, emulate=emulate)

     def test_swap_submodule_linears(self):
         class MLP(nn.Module):
@@ -378,7 +378,7 @@ def __init__(self, dim: int):
             [Float8Linear, Float8DynamicLinear], [True, False]
         ):
             model = nn.Sequential(MLP(3), nn.Linear(3, 3), MLP(3))
-            model = swap_linear_with_float8_linear(model, module_cls, emulate)
+            model = swap_linear_with_float8_linear(model, module_cls, emulate=emulate)
             self.assertIsInstance(model[0].lin1, module_cls)
             self.assertIsInstance(model[0].lin2, module_cls)
             self.assertIsInstance(model[1], module_cls)
@@ -398,7 +398,7 @@ def __init__(self, dim: int):
         model = nn.Sequential(MLP(3), nn.Linear(3, 3), MLP(3))
         skip_fqn_list = ["2", "0.lin2"]
         model = swap_linear_with_float8_linear(
-            model, module_cls, emulate, skip_fqn_list
+            model, module_cls, emulate=emulate, skip_fqn_list=skip_fqn_list
        )
        self.assertIsInstance(model[0].lin1, module_cls)
        self.assertNotIsInstance(model[0].lin2, module_cls)
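These tests only exercise how the new flag is plumbed through the swap; the docstring above describes it as casting activations to fp8 "using module hooks". As a rough, generic illustration of that PyTorch mechanism (explicitly not this repository's hook implementation), a forward pre-hook can rewrite a module's inputs right before forward() runs:

import torch
import torch.nn as nn

def cast_inputs_pre_hook(module, args):
    # Illustrative stand-in: a real fp8 hook would construct a Float8Tensor here.
    # This sketch only shows where such a cast executes (just before forward()).
    return tuple(a.to(torch.bfloat16) if torch.is_tensor(a) else a for a in args)

lin = nn.Linear(4, 4).to(torch.bfloat16)
handle = lin.register_forward_pre_hook(cast_inputs_pre_hook)
out = lin(torch.randn(2, 4))  # float32 input is cast by the hook before the matmul
handle.remove()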
