Commit 4b7c6db

GregoryComer authored and facebook-github-bot committed

Add fp16 qb4w linear test coverage (#3626)

Summary: Pull Request resolved: #3626. Adds op-level linear test coverage in the ExecuTorch XNNPACK backend for 4-bit blockwise weights with fp16 compute.

Reviewed By: digantdesai

Differential Revision: D57335871

fbshipit-source-id: 4813941c1ac63e21fcb5b5c4f563720b7c74ba79

1 parent a397bff · commit 4b7c6db

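For context, "qb4w" here denotes weights quantized to 4-bit integers with blockwise (groupwise) scales that are dequantized before the matmul. A minimal sketch of symmetric blockwise int4 quantization, using illustrative names that are not taken from the test file (the real reference path lives in ManualDQLinear):

    import torch

    def quantize_blockwise_int4(w: torch.Tensor, group_size: int):
        # Symmetric 4-bit blockwise quantization of an [oc, ic] weight:
        # one fp scale per (row, block), integer-valued entries in [-8, 7].
        oc, ic = w.shape
        blocks = w.reshape(oc, ic // group_size, group_size)
        scales = blocks.abs().amax(dim=-1, keepdim=True) / 7.0
        q = torch.clamp(torch.round(blocks / scales), -8, 7)
        return q.reshape(oc, ic), scales.squeeze(-1)

    w = torch.randn(4, 32)
    q, s = quantize_blockwise_int4(w, group_size=8)
    w_deq = (q.reshape(4, 4, 8) * s.unsqueeze(-1)).reshape(4, 32)
    print((w - w_deq).abs().max())  # small quantization error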
File tree

1 file changed: +37 −4 lines changed

backends/xnnpack/test/ops/linear.py

Lines changed: 37 additions & 4 deletions
@@ -412,9 +412,10 @@ def __init__(
         )
         self.quant_weight_per_channel()
 
-        # TODO - change bias dtyoe to arg.dtype
         self.bias = (
-            torch.nn.Parameter(torch.randn(self.oc), requires_grad=False)
+            torch.nn.Parameter(
+                torch.randn(self.oc).to(self.op_dtype), requires_grad=False
+            )
             if use_bias
             else None
         )
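The cast matters because torch.nn.functional.linear expects input, weight, and bias to share a dtype, so on the new fp16 path an fp32 bias typically raises a dtype-mismatch RuntimeError on CPU. A standalone sketch of the failure mode (not part of the test file):

    import torch

    x = torch.randn(2, 8, dtype=torch.float16)
    w = torch.randn(4, 8, dtype=torch.float16)
    b = torch.randn(4)  # fp32 bias, as the old code constructed

    try:
        torch.nn.functional.linear(x, w, b)  # mixed dtypes
    except RuntimeError as e:
        print("dtype mismatch:", e)

    y = torch.nn.functional.linear(x, w, b.to(x.dtype))
    print(y.dtype)  # torch.float16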
@@ -595,14 +596,14 @@ def fwd_weight_per_channel_group(self) -> torch.Tensor:
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         # Input
-        input = self.fwd_input_per_token(input)
+        input = self.fwd_input_per_token(input).to(self.op_dtype)
 
         # Weights
         w = (
             self.fwd_weight_per_channel_group()
             if self.w_scales.ndim == 2
             else self.fwd_weight_per_channel()
-        )
+        ).to(self.op_dtype)
         assert isinstance(w, torch.Tensor)
         return torch.nn.functional.linear(input, w, self.bias)
 
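Both casts follow the same logic: the dequantization helpers presumably compute in fp32 (scale arithmetic stays fp32), so under fp16 their results have to be cast back to self.op_dtype before the final linear. A rough sketch of the pattern, with hypothetical names not taken from the test file:

    import torch

    op_dtype = torch.float16
    q = torch.randint(-128, 128, (2, 8), dtype=torch.int8)
    scale = torch.rand(2, 1)              # fp32 per-token scales
    deq = q.to(torch.float32) * scale     # dequant lands in fp32
    deq = deq.to(op_dtype)                # cast back to the op dtype
    print(deq.dtype)                      # torch.float16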
@@ -734,6 +735,38 @@ def test_qd8_fp32_per_token_weight_per_channel_group_int4(self):
                 use_bias=use_bias,
             )
 
+    def test_qd8_fp16_per_token_weight_per_channel_group_int4(self):
+        M_sizes = [1, 2, 17, 31]
+        K_sizes = [8, 32, 64, 128]
+        bl_sizes = [8, 16, 16, 32]
+        N_sizes = [2, 17, 92, 128]
+
+        for use_bias in [True, False]:
+            for i, _ in enumerate(M_sizes):
+                M = int(M_sizes[i])
+                K = int(K_sizes[i])
+                N = int(N_sizes[i])
+                bl = int(bl_sizes[i])
+                mod = self.ManualDQLinear(
+                    input_channels=K,
+                    output_channels=N,
+                    weight_n_bit=4,
+                    dtype=torch.float16,
+                    group_size=bl,
+                    force_groupwise_quant=True,
+                    use_bias=use_bias,
+                )
+
+                inputs = (torch.randn(1, M, K, dtype=torch.float16),)
+                self._test_manual_dq_linear(
+                    mod,
+                    inputs,
+                    weight_groupwise=True,
+                    use_bias=use_bias,
+                    atol=0.1,
+                    rtol=0.1,
+                )
+
     def _test_linear(
         self,
         make_module,
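The relatively loose atol=0.1 / rtol=0.1 account for fp16 precision stacked on top of 4-bit blockwise weight quantization. An illustrative check of what such tolerances admit, using names that are not from the test harness:

    import torch

    ref = torch.randn(1, 17, 32, dtype=torch.float16)
    out = ref + 0.05  # deviation within the 0.1 absolute tolerance

    torch.testing.assert_close(out, ref, atol=0.1, rtol=0.1)  # passes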
