
Commit 88d9407

vkuzo authored and facebook-github-bot committed
switch inference tests to use Float8Linear (#297)
Summary:
Pull Request resolved: #297

Since the inference logic isn't used yet, save some time by switching the tests directly to Float8Linear instead of testing both versions.

Reviewed By: drisspg
Differential Revision: D59305794
fbshipit-source-id: 8f05c26e923a762d6eb5d08676e76473c9c362b7
1 parent 412222b commit 88d9407

File tree

1 file changed: +10 -4 lines

test/test_inference_flows.py

Lines changed: 10 additions & 4 deletions
@@ -13,7 +13,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
+from float8_experimental.float8_linear import Float8Linear, TensorScalingType
 from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
 from float8_experimental.float8_tensor import Float8Tensor
 from float8_experimental.float8_utils import compute_error
@@ -193,7 +193,10 @@ def test_fp8_save_and_load(self, dtype: torch.dtype):
         fp8_mlp.reset_parameters()
         swap_linear_with_float8_linear(
             fp8_mlp,
-            Float8DynamicLinear,
+            Float8Linear,
+            scaling_type_x=TensorScalingType.DYNAMIC,
+            scaling_type_w=TensorScalingType.DYNAMIC,
+            scaling_type_dL_dY=TensorScalingType.DYNAMIC,
         )

         # Train the model
@@ -210,12 +213,15 @@ def test_fp8_save_and_load(self, dtype: torch.dtype):
         # Reset buffer position to the beginning
         buffer.seek(0)

-        # Later on you load the model, will be w/ Float8DynamicLinear on meta device
+        # Later on you load the model, will be w/ Float8Linear on meta device
         with torch.device("meta"):
             new_fp8_mlp = FeedForward().to(dtype=dtype)
             swap_linear_with_float8_linear(
                 new_fp8_mlp,
-                Float8DynamicLinear,
+                Float8Linear,
+                scaling_type_x=TensorScalingType.DYNAMIC,
+                scaling_type_w=TensorScalingType.DYNAMIC,
+                scaling_type_dL_dY=TensorScalingType.DYNAMIC,
            )

         # Load the actual data
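
For context, the call pattern introduced by this diff swaps a model's nn.Linear modules for Float8Linear, with dynamic scaling selected for the input (x), the weight (w), and the output gradient (dL_dY), which mirrors the Float8DynamicLinear configuration it replaces. Below is a minimal sketch of that pattern, assuming float8_experimental at this revision is importable; the small FeedForward module is a hypothetical stand-in for the MLP defined in test/test_inference_flows.py.

import torch
import torch.nn as nn

from float8_experimental.float8_linear import Float8Linear, TensorScalingType
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear


class FeedForward(nn.Module):
    # Hypothetical stand-in for the MLP used by the test.
    def __init__(self, dim: int = 64):
        super().__init__()
        self.w1 = nn.Linear(dim, dim, bias=False)
        self.w2 = nn.Linear(dim, dim, bias=False)

    def forward(self, x):
        return self.w2(torch.relu(self.w1(x)))


fp8_mlp = FeedForward().to(dtype=torch.bfloat16)

# Swap every nn.Linear in place for Float8Linear; all three scaling
# types use dynamic scaling, mirroring what the updated test passes.
swap_linear_with_float8_linear(
    fp8_mlp,
    Float8Linear,
    scaling_type_x=TensorScalingType.DYNAMIC,
    scaling_type_w=TensorScalingType.DYNAMIC,
    scaling_type_dL_dY=TensorScalingType.DYNAMIC,
)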
