This repository was archived by the owner on Aug 7, 2024. It is now read-only.

[6/x] switch inference tests to use Float8Linear #297

Closed · wants to merge 1 commit
14 changes: 10 additions & 4 deletions test/test_inference_flows.py
@@ -13,7 +13,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
+from float8_experimental.float8_linear import Float8Linear, TensorScalingType
 from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
 from float8_experimental.float8_tensor import Float8Tensor
 from float8_experimental.float8_utils import compute_error
@@ -193,7 +193,10 @@ def test_fp8_save_and_load(self, dtype: torch.dtype):
         fp8_mlp.reset_parameters()
         swap_linear_with_float8_linear(
             fp8_mlp,
-            Float8DynamicLinear,
+            Float8Linear,
+            scaling_type_x=TensorScalingType.DYNAMIC,
+            scaling_type_w=TensorScalingType.DYNAMIC,
+            scaling_type_dL_dY=TensorScalingType.DYNAMIC,
         )

         # Train the model
@@ -210,12 +213,15 @@ def test_fp8_save_and_load(self, dtype: torch.dtype):
         # Reset buffer position to the beginning
         buffer.seek(0)

-        # Later on you load the model, will be w/ Float8DynamicLinear on meta device
+        # Later on you load the model, will be w/ Float8Linear on meta device
         with torch.device("meta"):
             new_fp8_mlp = FeedForward().to(dtype=dtype)
             swap_linear_with_float8_linear(
                 new_fp8_mlp,
-                Float8DynamicLinear,
+                Float8Linear,
+                scaling_type_x=TensorScalingType.DYNAMIC,
+                scaling_type_w=TensorScalingType.DYNAMIC,
+                scaling_type_dL_dY=TensorScalingType.DYNAMIC,
             )

         # Load the actual data
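
For reference, below is a minimal sketch of the swap pattern the tests move to in this PR. It only uses the API surface visible in the diff (`Float8Linear`, `TensorScalingType`, `swap_linear_with_float8_linear`); the `FeedForward` module is a hypothetical stand-in for the MLP defined in `test_inference_flows.py`, and running it end to end may require a CUDA environment with float8 support.

```python
# Sketch of the new swap pattern, assuming the float8_experimental API shown
# in the diff above. FeedForward is a hypothetical stand-in for the test MLP.
import torch
import torch.nn as nn

from float8_experimental.float8_linear import Float8Linear, TensorScalingType
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear


class FeedForward(nn.Module):
    """Hypothetical stand-in for the MLP used in test_inference_flows.py."""

    def __init__(self, dim: int = 16):
        super().__init__()
        self.w1 = nn.Linear(dim, dim)
        self.w2 = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(torch.relu(self.w1(x)))


model = FeedForward()
# Swap every nn.Linear for Float8Linear. Passing DYNAMIC for all three
# scaling types (input x, weight w, and output gradient dL_dY) configures
# Float8Linear to scale dynamically, matching the behavior of the removed
# Float8DynamicLinear class.
swap_linear_with_float8_linear(
    model,
    Float8Linear,
    scaling_type_x=TensorScalingType.DYNAMIC,
    scaling_type_w=TensorScalingType.DYNAMIC,
    scaling_type_dL_dY=TensorScalingType.DYNAMIC,
)
```

The design point, as far as the diff shows, is that `Float8Linear` with all-dynamic scaling subsumes `Float8DynamicLinear`, so the inference tests can drop the separate class and configure scaling per tensor instead.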