Skip to content

Commit fde8c70

Browse files
author
Nathanael See
committed
Update on "[ET-VK][int4] patch 4-bit source transformation quantizer to support linear modules with biases"
While LLaMA's linear modules do not have biases, some other models' linear modules do. Add support for linear modules with biases to the source transform quantizer. Differential Revision: [D69072087](https://our.internmc.facebook.com/intern/diff/D69072087/) [ghstack-poisoned]
2 parents 6e36efe + e5f28dc commit fde8c70

File tree

29 files changed

+983
-427
lines changed

29 files changed

+983
-427
lines changed

.buckconfig

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,6 @@
3333
**/.git, \
3434
cmake-out, \
3535
pip-out
36+
37+
[buck2]
38+
restarter=true

backends/arm/test/misc/test_tosa_spec.py

Lines changed: 37 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,22 @@
2020
"TOSA-0.80+MI+8k",
2121
"TOSA-0.80+BI+u55",
2222
]
23-
test_valid_1_00_strings = [
24-
"TOSA-1.00.0+INT+FP+fft",
25-
"TOSA-1.00.0+FP+bf16+fft",
26-
"TOSA-1.00.0+INT+int4+cf",
27-
"TOSA-1.00.0+FP+cf+bf16+8k",
28-
"TOSA-1.00.0+FP+INT+bf16+fft+int4+cf",
29-
"TOSA-1.00.0+FP+INT+fft+int4+cf+8k",
23+
test_valid_1_0_strings = [
24+
"TOSA-1.0.0+INT+FP+fft",
25+
"TOSA-1.0.0+FP+bf16+fft",
26+
"TOSA-1.0.0+INT+int4+cf",
27+
"TOSA-1.0.0+FP+cf+bf16+8k",
28+
"TOSA-1.0.0+FP+INT+bf16+fft+int4+cf",
29+
"TOSA-1.0.0+FP+INT+fft+int4+cf+8k",
30+
"TOSA-1.0+INT+FP+fft",
31+
"TOSA-1.0+FP+bf16+fft",
32+
"TOSA-1.0+INT+int4+cf",
33+
"TOSA-1.0+FP+cf+bf16+8k",
34+
"TOSA-1.0+FP+INT+bf16+fft+int4+cf",
35+
"TOSA-1.0+FP+INT+fft+int4+cf+8k",
3036
]
3137

32-
test_valid_1_00_extensions = {
38+
test_valid_1_0_extensions = {
3339
"INT": ["int16", "int4", "var", "cf"],
3440
"FP": ["bf16", "fp8e4m3", "fp8e5m2", "fft", "var", "cf"],
3541
}
@@ -40,19 +46,19 @@
4046
"TOSA-0.80+8k",
4147
"TOSA-0.80+BI+MI",
4248
"TOSA-0.80+BI+U55",
43-
"TOSA-1.00.0+fft",
44-
"TOSA-1.00.0+fp+bf16+fft",
45-
"TOSA-1.00.0+INT+INT4+cf",
46-
"TOSA-1.00.0+BI",
47-
"TOSA-1.00.0+FP+FP+INT",
48-
"TOSA-1.00.0+FP+CF+bf16",
49-
"TOSA-1.00.0+BF16+fft+int4+cf+INT",
49+
"TOSA-1.0.0+fft",
50+
"TOSA-1.0.0+fp+bf16+fft",
51+
"TOSA-1.0.0+INT+INT4+cf",
52+
"TOSA-1.0.0+BI",
53+
"TOSA-1.0.0+FP+FP+INT",
54+
"TOSA-1.0.0+FP+CF+bf16",
55+
"TOSA-1.0.0+BF16+fft+int4+cf+INT",
5056
]
5157

5258
test_compile_specs = [
5359
([CompileSpec("tosa_version", "TOSA-0.80+BI".encode())],),
5460
([CompileSpec("tosa_version", "TOSA-0.80+BI+u55".encode())],),
55-
([CompileSpec("tosa_version", "TOSA-1.00.0+INT".encode())],),
61+
([CompileSpec("tosa_version", "TOSA-1.0.0+INT".encode())],),
5662
]
5763

5864
test_compile_specs_no_version = [
@@ -70,8 +76,8 @@ def test_version_string_0_80(self, version_string: str):
7076
assert isinstance(tosa_spec, Tosa_0_80)
7177
assert tosa_spec.profile in ["BI", "MI"]
7278

73-
@parameterized.expand(test_valid_1_00_strings) # type: ignore[misc]
74-
def test_version_string_1_00(self, version_string: str):
79+
@parameterized.expand(test_valid_1_0_strings) # type: ignore[misc]
80+
def test_version_string_1_0(self, version_string: str):
7581
tosa_spec = TosaSpecification.create_from_string(version_string)
7682
assert isinstance(tosa_spec, Tosa_1_00)
7783
assert [profile in ["INT", "FP"] for profile in tosa_spec.profiles].count(
@@ -80,7 +86,7 @@ def test_version_string_1_00(self, version_string: str):
8086

8187
for profile in tosa_spec.profiles:
8288
assert [
83-
e in test_valid_1_00_extensions[profile] for e in tosa_spec.extensions
89+
e in test_valid_1_0_extensions[profile] for e in tosa_spec.extensions
8490
]
8591

8692
@parameterized.expand(test_invalid_strings) # type: ignore[misc]
@@ -103,3 +109,15 @@ def test_create_from_invalid_compilespec(self, compile_specs: list[CompileSpec])
103109
tosa_spec = TosaSpecification.create_from_compilespecs(compile_specs)
104110

105111
assert tosa_spec is None
112+
113+
@parameterized.expand(test_valid_0_80_strings)
114+
def test_correct_string_representation_0_80(self, version_string: str):
115+
tosa_spec = TosaSpecification.create_from_string(version_string)
116+
assert isinstance(tosa_spec, Tosa_0_80)
117+
assert f"{tosa_spec}" == version_string
118+
119+
@parameterized.expand(test_valid_1_0_strings)
120+
def test_correct_string_representation_1_0(self, version_string: str):
121+
tosa_spec = TosaSpecification.create_from_string(version_string)
122+
assert isinstance(tosa_spec, Tosa_1_00)
123+
assert f"{tosa_spec}" == version_string

backends/arm/tosa_specification.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2024 Arm Limited and/or its affiliates.
1+
# Copyright 2024-2025 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
@@ -14,7 +14,9 @@
1414
import re
1515
from typing import List
1616

17-
from executorch.exir.backend.compile_spec_schema import CompileSpec
17+
from executorch.exir.backend.compile_spec_schema import ( # type: ignore[import-untyped]
18+
CompileSpec,
19+
)
1820
from packaging.version import Version
1921

2022

@@ -131,7 +133,7 @@ def __init__(self, version: Version, extras: List[str]):
131133
def __repr__(self):
132134
extensions = ""
133135
if self.level_8k:
134-
extensions += "+8K"
136+
extensions += "+8k"
135137
if self.is_U55_subset:
136138
extensions += "+u55"
137139
return f"TOSA-{str(self.version)}+{self.profile}{extensions}"
@@ -207,7 +209,10 @@ def _get_extensions_string(self) -> str:
207209
return "".join(["+" + e for e in self.extensions])
208210

209211
def __repr__(self):
210-
return f"TOSA-{self.version}{self._get_profiles_string()}{self._get_profiles_string()}"
212+
extensions = self._get_extensions_string()
213+
if self.level_8k:
214+
extensions += "+8k"
215+
return f"TOSA-{self.version}{self._get_profiles_string()}{extensions}"
211216

212217
def __hash__(self) -> int:
213218
return hash(str(self.version) + self._get_profiles_string())

backends/cadence/aot/functions_hifi.yaml

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,11 +71,11 @@
7171
kernels:
7272
- arg_meta: null
7373
kernel_name: cadence::impl::HiFi::full_out
74-
74+
7575
- op: gt.Scalar_out
7676
kernels:
7777
- arg_meta: null
78-
kernel_name: torch::executor::gt_scalar_out
78+
kernel_name: torch::executor::gt_scalar_out
7979

8080
- op: gelu.out
8181
kernels:
@@ -100,7 +100,7 @@
100100
- op: mean.out
101101
kernels:
102102
- arg_meta: null
103-
kernel_name: cadence::impl::HiFi::mean_dim_out
103+
kernel_name: cadence::impl::HiFi::mean_dim_out
104104

105105
- op: minimum.out
106106
kernels:
@@ -213,3 +213,13 @@
213213
kernels:
214214
- arg_meta: null
215215
kernel_name: cadence::impl::HiFi::quantized_linear_per_tensor_out
216+
217+
- func: cadence::quantized_fully_connected.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)
218+
kernels:
219+
- arg_meta: null
220+
kernel_name: cadence::impl::HiFi::quantized_fully_connected_out
221+
222+
- func: cadence::quantized_fully_connected.per_tensor_out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, int weight_zero_point, int out_multiplier, int out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)
223+
kernels:
224+
- arg_meta: null
225+
kernel_name: cadence::impl::HiFi::quantized_fully_connected_per_tensor_out

backends/cadence/hifi/operators/op_clamp.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ namespace impl {
4848
namespace HiFi {
4949
namespace native {
5050

51-
Tensor& clamp_tensor_out(
51+
Tensor& clamp_Tensor_out(
5252
RuntimeContext& ctx,
5353
const Tensor& in,
5454
const executorch::aten::optional<Tensor>& min_opt,

0 commit comments

Comments
 (0)