
Commit dc4b38f: add a test

1 parent: ea368d1

1 file changed: +19 -1 lines

test/test_compile.py (19 additions, 1 deletion)
@@ -11,7 +11,13 @@
 
 import torch
 import torch.nn as nn
-from float8_experimental.float8_linear_utils import get_float8_linear, LinearType
+from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
+from float8_experimental.float8_linear_utils import (
+    get_float8_linear,
+    LinearType,
+    swap_linear_with_float8_linear,
+    sync_float8_amax_and_scale_history,
+)
 from float8_experimental.float8_tensor import Float8Tensor
 
 from torch._dynamo.test_case import TestCase as DynamoTestCase
@@ -199,5 +205,17 @@ def test_float8_graph_output(self):
         )
 
 
+@unittest.skipIf(not torch.cuda.is_available() or not is_H100, "CUDA not available")
+def test_sync_amax_func():
+    cnts = CompileCounterWithBackend("inductor")
+    module = torch.nn.Sequential(
+        nn.Linear(16, 32, bias=True), nn.ReLU(), nn.Linear(32, 16, bias=True)
+    )
+    float8_mod = swap_linear_with_float8_linear(module, Float8DynamicLinear)
+    compiled_swap_func = torch.compile(sync_float8_amax_and_scale_history, backend=cnts)
+    compiled_swap_func(float8_mod)
+    assert cnts.frame_count == 0, "Compiled graph should have 0 frames!"
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
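
For readers unfamiliar with the helper: CompileCounterWithBackend (from torch._dynamo.testing) is a callable backend that counts how many frames Dynamo compiles before delegating to a real backend, which is what lets the new test assert on cnts.frame_count. Below is a minimal sketch of the same technique, independent of this commit; the toy add_one function is invented for illustration, and the "eager" backend is used so no CUDA or Inductor build is required.

    # Minimal sketch (not from this commit): counting Dynamo compilations with
    # CompileCounterWithBackend, the same helper the test above passes to
    # torch.compile via `backend=cnts`.
    import torch
    from torch._dynamo.testing import CompileCounterWithBackend

    def add_one(x):
        # Toy function invented for illustration; any traceable tensor op works.
        return x + 1

    cnts = CompileCounterWithBackend("eager")  # count frames, then run via the eager backend
    compiled_fn = torch.compile(add_one, backend=cnts)
    compiled_fn(torch.randn(4))

    # Dynamo traced and compiled exactly one graph for add_one.
    assert cnts.frame_count == 1

By the same mechanism, the new test expects frame_count == 0: with Float8DynamicLinear modules there is presumably no amax/scale history to sync, so sync_float8_amax_and_scale_history gives Dynamo nothing to trace. To run just this test (skipped unless CUDA and an H100 are available): pytest test/test_compile.py -k test_sync_amax_func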
