 import copy
 from enum import auto, Enum

+import float8_experimental.config as fp8_config
+
 import torch
 import torch.distributed as dist
 from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
@@ -94,21 +96,23 @@ def swap_linear_with_float8_linear(
         swap_linear_with_float8_linear(child, module, emulate)


-def get_float8_layers(model: torch.nn.Module, fp8_classes=None):
-    if fp8_classes is None:
-        fp8_classes = Float8Linear
+def get_float8_layers(model: torch.nn.Module):
+    """Iterates through the model and returns all the Float8Linear layers.
+    Args:
+        model (torch.nn.Module): The model to look for Float8Linear layers in.
+    """

     # Get all fp8 layers and tensors
     fp8_layers = [
-        child for name, child in model.named_modules() if isinstance(child, fp8_classes)
+        child
+        for name, child in model.named_modules()
+        if isinstance(child, Float8Linear)
     ]

     return fp8_layers


-def sync_float8_amax_and_scale_history(
-    model: torch.nn.Module, fp8_classes=None, fp8_layers=None
-) -> None:
+def sync_float8_amax_and_scale_history(model: torch.nn.Module, fp8_layers=None) -> None:
     """
     Manages the float8 amax and scale bookkeeping. In detail, it does the
     following:
@@ -120,11 +124,13 @@ def sync_float8_amax_and_scale_history(

     TODO(future): design the UX for this (context manager, etc)

+    PERFORMANCE NOTE:
+        When you can, it is much more efficient to call get_float8_layers once at
+        the beginning of the training loop and pass the result to this function,
+        because of how this interacts with torch.compile.
+
     Args:
         model (torch.nn.Module): The model to track amaxes for
-        fp8_classes (optional): The fp8 classes to look for in the model.
-            The default is Float8Linear.
-            When using with TP, users can pass in the customized TP classes instead.
         fp8_layers (optional): If fp8_layers are provided, fp8_classes are ignored,
             and we loop over all fp8_layers to sync and update amax scale histories.
             Users can use get_float8_layers to get all fp8 layers.
@@ -136,7 +142,7 @@ def sync_float8_amax_and_scale_history(
     # make the history update faster.

     if fp8_layers is None:
-        fp8_layers = get_float8_layers(model, fp8_classes)
+        fp8_layers = get_float8_layers(model)

     if dist.is_initialized():
         fp8_amax_x_tensor = torch.tensor(
@@ -210,5 +216,6 @@ def sync_float8_amax_and_scale_history(

         #
         # 4. set a flag to signal amaxes/scales are ready
-        #
-        child.amax_and_scale_synced = True
+        # We only update the flag if we know it will be checked by the modules
+        if fp8_config.enable_amax_init:
+            child.amax_and_scale_synced = True
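
A minimal usage sketch of the pattern the PERFORMANCE NOTE recommends: collect the Float8Layer modules once before the training loop and pass them to every sync call. The surrounding objects (model, optimizer, data_loader) and the import paths are assumptions for illustration, not part of this change.

from float8_experimental.float8_linear_utils import (
    get_float8_layers,
    sync_float8_amax_and_scale_history,
)

# Assumed setup: a model whose linear layers were already swapped to
# Float8Linear, plus an optimizer and a data loader (placeholders below).
model = ...
optimizer = ...
data_loader = ...

# Walk the module tree once, outside the loop.
fp8_layers = get_float8_layers(model)

for batch in data_loader:
    optimizer.zero_grad()
    loss = model(batch).sum()
    loss.backward()
    # Reusing the precomputed list avoids re-walking the model every step,
    # which matters when this loop is wrapped in torch.compile.
    sync_float8_amax_and_scale_history(model, fp8_layers=fp8_layers)
    optimizer.step()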
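
A follow-on note on the last hunk: the amax_and_scale_synced flag is now written only when fp8_config.enable_amax_init is set, so code that checks the flag should be gated the same way. A sketch, assuming enable_amax_init is the module-level boolean on float8_experimental.config that the new import refers to, and reusing model and fp8_layers from the sketch above:

import float8_experimental.config as fp8_config

sync_float8_amax_and_scale_history(model, fp8_layers=fp8_layers)

if fp8_config.enable_amax_init:
    # The per-layer flag is only set in this configuration.
    assert all(layer.amax_and_scale_synced for layer in fp8_layers)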