Skip to content

Commit 36cd162

Browse files
committed
Add Vulkan Quantizer to Llama export lib
Pull Request resolved: #6169 TSIA. Note that only 8 bit weight only quantization is supported for now since `VulkanQuantizer` does not support 4 bit weight only quantization at the moment. ghstack-source-id: 247613963 @exported-using-ghexport Differential Revision: [D64249615](https://our.internmc.facebook.com/intern/diff/D64249615/)
1 parent a56d121 commit 36cd162

File tree

3 files changed

+28
-0
lines changed

3 files changed

+28
-0
lines changed

examples/models/llama2/export_llama_lib.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
get_pt2e_quantization_params,
4242
get_pt2e_quantizers,
4343
get_qnn_quantizer,
44+
get_vulkan_quantizer,
4445
)
4546
from executorch.util.activation_memory_profiler import generate_memory_trace
4647

@@ -147,6 +148,7 @@ def build_args_parser() -> argparse.ArgumentParser:
147148
"coreml_8a_c4w",
148149
"coreml_baseline_8a_c8w",
149150
"coreml_baseline_8a_c4w",
151+
"vulkan_8w",
150152
],
151153
help="Use PT2E quantization. Comma separated options. e.g. xnnpack_dynamic (for per channel 8 bit weight), xnnpack_dynamic_qc4 (for per channel 4 bit weight), embedding.",
152154
)
@@ -548,6 +550,12 @@ def get_quantizer_and_quant_params(args):
548550
assert len(quantizers) == 0, "Should not enable both xnnpack / qnn and coreml"
549551
coreml_quantizer = get_coreml_quantizer(args.pt2e_quantize)
550552
quantizers.append(coreml_quantizer)
553+
if args.vulkan and args.pt2e_quantize:
554+
assert (
555+
len(quantizers) == 0
556+
), "Should not enable both vulkan and other quantizers"
557+
vulkan_quantizer = get_vulkan_quantizer(args.pt2e_quantize)
558+
quantizers.append(vulkan_quantizer)
551559
logging.info(f"Applying quantizers: {quantizers}")
552560
return pt2e_quant_params, quantizers, quant_dtype
553561

extension/llm/export/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ runtime.python_library(
3131
"//executorch/backends/qualcomm/quantizer:quantizer",
3232
"//executorch/backends/transforms:duplicate_dynamic_quant_chain",
3333
"//executorch/backends/vulkan/partitioner:vulkan_partitioner",
34+
"//executorch/backends/vulkan/quantizer:vulkan_quantizer",
3435
"//executorch/backends/xnnpack/partition:xnnpack_partitioner",
3536
"//executorch/exir:lib",
3637
"//executorch/exir/backend:backend_details",

extension/llm/export/quantizer_lib.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,3 +260,22 @@ def get_coreml_quantizer(pt2e_quantize: str):
260260
raise ValueError(f"Unsupported Core ML quantizer specification {pt2e_quantize}")
261261

262262
return quantizer
263+
264+
265+
def get_vulkan_quantizer(pt2e_quantize: str):
    """Build a ``VulkanQuantizer`` from a pt2e quantization spec string.

    Only ``"vulkan_8w"`` (8-bit per-channel weight-only quantization) is
    supported at the moment; any other spec raises ``ValueError``.

    Args:
        pt2e_quantize: The quantization spec selected on the command line.

    Returns:
        A ``VulkanQuantizer`` with the weight-only config applied globally.

    Raises:
        ValueError: If *pt2e_quantize* is not a recognized Vulkan spec.
    """
    # Imported lazily so the Vulkan backend is only required when actually used.
    from executorch.backends.vulkan.quantizer.vulkan_quantizer import (
        get_weight_quantization_config,
        VulkanQuantizer,
    )

    # Guard clause: reject anything other than the single supported spec.
    if pt2e_quantize != "vulkan_8w":
        raise ValueError(f"Unsupported Vulkan quantizer specification {pt2e_quantize}")

    # 8-bit signed weight range, quantized per output channel.
    weight_config = get_weight_quantization_config(
        is_per_channel=True,
        weight_qmin=-128,
        weight_qmax=127,
    )
    return VulkanQuantizer().set_global(weight_config)

0 commit comments

Comments
 (0)