This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit bb0a585

update signature and add note about usage

1 parent e8853da

File tree

1 file changed: +20, -13 lines


float8_experimental/float8_linear_utils.py

Lines changed: 20 additions & 13 deletions
@@ -6,6 +6,8 @@
 import copy
 from enum import auto, Enum
 
+import float8_experimental.config as fp8_config
+
 import torch
 import torch.distributed as dist
 from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
@@ -94,21 +96,23 @@ def swap_linear_with_float8_linear(
             swap_linear_with_float8_linear(child, module, emulate)
 
 
-def get_float8_layers(model: torch.nn.Module, fp8_classes=None):
-    if fp8_classes is None:
-        fp8_classes = Float8Linear
+def get_float8_layers(model: torch.nn.Module):
+    """Iterates through the model and returns all the Float8Linear layers.
+    Args:
+        model (torch.nn.Module): The model to look for Float8Linear layers in.
+    """
 
     # Get all fp8 layers and tensors
     fp8_layers = [
-        child for name, child in model.named_modules() if isinstance(child, fp8_classes)
+        child
+        for name, child in model.named_modules()
+        if isinstance(child, Float8Linear)
     ]
 
     return fp8_layers
 
 
-def sync_float8_amax_and_scale_history(
-    model: torch.nn.Module, fp8_classes=None, fp8_layers=None
-) -> None:
+def sync_float8_amax_and_scale_history(model: torch.nn.Module, fp8_layers=None) -> None:
     """
     Manages the float8 amax and scale bookkeeping. In detail, it does the
     following:
@@ -120,11 +124,13 @@ def sync_float8_amax_and_scale_history(
 
     TODO(future): design the UX for this (context manager, etc)
 
+    PERFORMANCE NOTE:
+        When you can, it is much more efficient to call get_float8_layers once at
+        the beginning of the training loop and pass the result to this function,
+        because of how this interacts with torch.compile.
+
     Args:
         model (torch.nn.Module): The model to track amaxes for
-        fp8_classes (optional): The fp8 classes to look for in the model.
-            The default is Float8Linear.
-            When using with TP, users can pass in the customized TP classes instead.
         fp8_layers (optional): If fp8_layers are provided, fp8_classes are ignored,
            and we loop over all fp8_layers to sync and update amax scale histories.
            Users can use get_float8_layers to get all fp8 layers.
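
As a usage sketch of the call pattern the performance note recommends: the toy model, optimizer, and random inputs below are illustrative placeholders, not part of this commit, and in real use the model would first be converted (e.g. with swap_linear_with_float8_linear) so it actually contains Float8Linear layers; on a plain model get_float8_layers simply returns an empty list, so the loop still runs end to end.

import torch

from float8_experimental.float8_linear_utils import (
    get_float8_layers,
    sync_float8_amax_and_scale_history,
)

# Placeholder model; a real setup would convert its Linear layers to
# Float8Linear before collecting them.
model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.Linear(16, 4))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Collect the Float8Linear layers once, before the training loop,
# instead of re-walking the module tree on every step.
fp8_layers = get_float8_layers(model)

for _ in range(3):
    loss = model(torch.randn(8, 16)).sum()
    loss.backward()
    # Pass the precomputed list so the function does not rediscover the
    # layers on every call (which interacts poorly with torch.compile).
    sync_float8_amax_and_scale_history(model, fp8_layers=fp8_layers)
    optimizer.step()
    optimizer.zero_grad()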
@@ -136,7 +142,7 @@ def sync_float8_amax_and_scale_history(
     # make the history update faster.
 
     if fp8_layers is None:
-        fp8_layers = get_float8_layers(model, fp8_classes)
+        fp8_layers = get_float8_layers(model)
 
     if dist.is_initialized():
         fp8_amax_x_tensor = torch.tensor(
@@ -210,5 +216,6 @@ def sync_float8_amax_and_scale_history(
 
         #
         # 4. set a flag to signal amaxes/scales are ready
-        #
-        child.amax_and_scale_synced = True
+        # We only update the flag if we know it will be checked by the modules
+        if fp8_config.enable_amax_init:
+            child.amax_and_scale_synced = True
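
For the new guard at the end of the diff, the flag is read from float8_experimental.config (imported as fp8_config at the top of the file); the attribute name enable_amax_init comes straight from the diff, while treating it as a module-level boolean that users toggle before training is an assumption in this sketch.

import float8_experimental.config as fp8_config

# Assumption: enable_amax_init is a module-level boolean. When amax
# initialization is disabled, the modules are not expected to check
# `amax_and_scale_synced`, so the sync helper skips setting it.
fp8_config.enable_amax_init = False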

0 commit comments