Skip to content
This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 2ea0ab5

Browse files
committed
enable weights only loading
1 parent 2ff810c commit 2ea0ab5

File tree

3 files changed

+27
-4
lines changed

3 files changed

+27
-4
lines changed

float8_experimental/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@
55
# LICENSE file in the root directory of this source tree.
66
# Let's define a few top-level things here
77
from float8_experimental.float8_linear import Float8Linear
8-
from float8_experimental.float8_tensor import Float8Tensor
8+
from float8_experimental.float8_tensor import Float8Tensor, ScaledMMConfig
99

1010
# Needed to load Float8Tensor with weights_only = True
1111
from torch.serialization import add_safe_globals
1212

13-
add_safe_globals([Float8Tensor])
13+
add_safe_globals([Float8Tensor, ScaledMMConfig])
1414

1515
__all__ = ["Float8Tensor", "Float8Linear"]

test/test_base.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#
44
# This source code is licensed under the BSD 3-Clause license found in the
55
# LICENSE file in the root directory of this source tree.
6+
import io
67
import itertools
78
import random
89
import unittest
@@ -12,6 +13,7 @@
1213

1314
import torch
1415
import torch.nn as nn
16+
1517
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
1618
from float8_experimental.float8_linear import Float8Linear
1719
from float8_experimental.float8_linear_utils import (
@@ -82,6 +84,25 @@ def test_split_cat(self):
8284
catted = torch.cat(splits, dim=0)
8385
assert bitwise_identical(fp8_a, catted)
8486

87+
def test_weights_only_load(self):
88+
module = nn.Linear(16, 16)
89+
# Save model state dict
90+
buffer = io.BytesIO()
91+
fp8_module = swap_linear_with_float8_linear(
92+
module,
93+
Float8DynamicLinear,
94+
from_float_kwargs={
95+
"pre_quantize_weight": True,
96+
"activation_scale": torch.tensor(
97+
[1.0], device="cuda", dtype=torch.float32
98+
),
99+
},
100+
)
101+
102+
torch.save(fp8_module.state_dict(), buffer)
103+
buffer.seek(0)
104+
_ = torch.load(buffer, weights_only=True)
105+
85106

86107
class TestFloat8Linear:
87108
def _test_linear_impl(

test/test_inference_flows.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -179,12 +179,14 @@ def test_fp8_save_and_load(self, compile_backend: str, dtype: torch.dtype):
179179
)
180180

181181
# Load the actual data
182-
new_fp8_mlp.load_state_dict(torch.load(buffer), strict=True, assign=True)
182+
new_fp8_mlp.load_state_dict(
183+
torch.load(buffer, weights_only=True), strict=True, assign=True
184+
)
183185

184186
# Dynamic Activations + Quantized Weights
185187
def quantize_dynamic_linear(x: nn.Module):
186188
if isinstance(x, Float8DynamicLinear):
187-
x.set_quantization_scales(True)
189+
x.set_quantization_scales(pre_quantize_weight=True)
188190
return x
189191

190192
new_fp8_mlp.apply(quantize_dynamic_linear)

0 commit comments

Comments
 (0)