Skip to content
This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit fa2f08a

Browse files
committed
unit test for precomputing scales
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
1 parent 9ef67fb commit fa2f08a

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

test/test_fsdp2/test_fsdp2_common.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import torch
77
import torch.distributed as dist
88
import torch.nn as nn
9-
from float8_experimental.fsdp_utils import precompute_float8_amax_for_fsdp
9+
from float8_experimental.fsdp_utils import precompute_float8_scale_for_fsdp
1010

1111

1212
def check_parity_no_mp(
@@ -31,7 +31,7 @@ def check_parity_no_mp(
3131
# TODO(future): add amax syncing once delayed scaling is supported
3232
optim.step()
3333
if model is fsdp_model and precompute:
34-
precompute_float8_amax_for_fsdp(model)
34+
precompute_float8_scale_for_fsdp(model)
3535
test_cls.assertEqual(losses[0], losses[1])
3636

3737

0 commit comments

Comments
 (0)