Skip to content
This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 29c059b

Browse files
committed
preallocate lists
1 parent e4f16f5 commit 29c059b

File tree

1 file changed

+13
-13
lines changed

1 file changed

+13
-13
lines changed

float8_experimental/float8_linear_utils.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -169,25 +169,25 @@ def sync_float8_amax_and_scale_history(model: torch.nn.Module, fp8_layers=None)
169169
return
170170

171171
# Loop over all fp8 layers and grab the needed tensors
172-
fp8_amax_x_tensor_list = []
173-
fp8_amax_w_tensor_list = []
174-
fp8_amax_dL_dY_tensor_list = []
172+
fp8_amax_x_tensor_list = [None] * len(fp8_layers)
173+
fp8_amax_w_tensor_list = [None] * len(fp8_layers)
174+
fp8_amax_dL_dY_tensor_list = [None] * len(fp8_layers)
175175

176-
fp8_x_amax_history_stack = []
177-
fp8_w_amax_history_stack = []
178-
fp8_dL_dY_amax_history_stack = []
176+
fp8_x_amax_history_stack = [None] * len(fp8_layers)
177+
fp8_w_amax_history_stack = [None] * len(fp8_layers)
178+
fp8_dL_dY_amax_history_stack = [None] * len(fp8_layers)
179179

180180
x_dtypes = set()
181181
scale_fn_recipes = set()
182182

183-
for child in fp8_layers:
184-
fp8_amax_x_tensor_list.append(child.fp8_amax_x)
185-
fp8_amax_w_tensor_list.append(child.fp8_amax_w)
186-
fp8_amax_dL_dY_tensor_list.append(child.fp8_amax_dL_dY)
183+
for idx, child in enumerate(fp8_layers):
184+
fp8_amax_x_tensor_list[idx] = child.fp8_amax_x
185+
fp8_amax_w_tensor_list[idx] = child.fp8_amax_w
186+
fp8_amax_dL_dY_tensor_list[idx] = child.fp8_amax_dL_dY
187187

188-
fp8_x_amax_history_stack.append(child.fp8_amax_history_x)
189-
fp8_w_amax_history_stack.append(child.fp8_amax_history_w)
190-
fp8_dL_dY_amax_history_stack.append(child.fp8_amax_history_dL_dY)
188+
fp8_x_amax_history_stack[idx] = child.fp8_amax_history_x
189+
fp8_w_amax_history_stack[idx] = child.fp8_amax_history_w
190+
fp8_dL_dY_amax_history_stack[idx] = child.fp8_amax_history_dL_dY
191191

192192
x_dtypes.add(child.last_seen_input_dtype)
193193
scale_fn_recipes.add(child.recipe.scale_fn_name)

0 commit comments

Comments
 (0)