@@ -169,25 +169,25 @@ def sync_float8_amax_and_scale_history(model: torch.nn.Module, fp8_layers=None)
         return
 
     # Loop over all fp8 layers and grab the needed tensors
-    fp8_amax_x_tensor_list = []
-    fp8_amax_w_tensor_list = []
-    fp8_amax_dL_dY_tensor_list = []
+    fp8_amax_x_tensor_list = [None] * len(fp8_layers)
+    fp8_amax_w_tensor_list = [None] * len(fp8_layers)
+    fp8_amax_dL_dY_tensor_list = [None] * len(fp8_layers)
 
-    fp8_x_amax_history_stack = []
-    fp8_w_amax_history_stack = []
-    fp8_dL_dY_amax_history_stack = []
+    fp8_x_amax_history_stack = [None] * len(fp8_layers)
+    fp8_w_amax_history_stack = [None] * len(fp8_layers)
+    fp8_dL_dY_amax_history_stack = [None] * len(fp8_layers)
 
     x_dtypes = set()
     scale_fn_recipes = set()
 
-    for child in fp8_layers:
-        fp8_amax_x_tensor_list.append(child.fp8_amax_x)
-        fp8_amax_w_tensor_list.append(child.fp8_amax_w)
-        fp8_amax_dL_dY_tensor_list.append(child.fp8_amax_dL_dY)
+    for idx, child in enumerate(fp8_layers):
+        fp8_amax_x_tensor_list[idx] = child.fp8_amax_x
+        fp8_amax_w_tensor_list[idx] = child.fp8_amax_w
+        fp8_amax_dL_dY_tensor_list[idx] = child.fp8_amax_dL_dY
 
-        fp8_x_amax_history_stack.append(child.fp8_amax_history_x)
-        fp8_w_amax_history_stack.append(child.fp8_amax_history_w)
-        fp8_dL_dY_amax_history_stack.append(child.fp8_amax_history_dL_dY)
+        fp8_x_amax_history_stack[idx] = child.fp8_amax_history_x
+        fp8_w_amax_history_stack[idx] = child.fp8_amax_history_w
+        fp8_dL_dY_amax_history_stack[idx] = child.fp8_amax_history_dL_dY
 
         x_dtypes.add(child.last_seen_input_dtype)
         scale_fn_recipes.add(child.recipe.scale_fn_name)
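For reference, a minimal standalone sketch of the pattern this change adopts: when the number of fp8 layers is known up front, preallocate the lists to len(fp8_layers) and fill them by index in a single enumerate() pass instead of growing them with append(). The Layer class and its amax attribute below are hypothetical stand-ins for the real per-layer buffers (fp8_amax_x, fp8_amax_history_x, ...), not part of this repository.

import torch


class Layer:
    def __init__(self, value: float):
        # Hypothetical stand-in for a per-layer amax buffer.
        self.amax = torch.tensor([value])


def gather_amax(layers):
    # Preallocate to the known length, then assign by index,
    # mirroring the preallocate-and-fill pattern in the diff above.
    amax_list = [None] * len(layers)
    for idx, layer in enumerate(layers):
        amax_list[idx] = layer.amax
    return amax_list


if __name__ == "__main__":
    layers = [Layer(0.5), Layer(2.0), Layer(1.25)]
    print(gather_amax(layers))  # [tensor([0.5000]), tensor([2.]), tensor([1.2500])]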