
Commit 1282332

nathanaelsee authored and facebook-github-bot committed
VulkanQuantizer for weight-only quantization on linear
Summary: Uses XNNPACKQuantizer as a base. For now, VulkanQuantizer annotates only for 8-bit weight-only static quantization on linear nodes, since 8-bit weight-quantized linear is currently implemented only in the form of weight_int8packed_mm.

Differential Revision: D61243540
1 parent ba3448c commit 1282332
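
For context, a minimal sketch of how this quantizer would be driven through the PT2E flow, assuming the standard torch.ao.quantization.quantize_pt2e entry points and torch._export.capture_pre_autograd_graph; the toy model and shapes are made up for illustration:

import torch
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

from executorch.backends.vulkan.quantizer.vulkan_quantizer import (
    get_static_quantization_config,
    VulkanQuantizer,
)

# Toy model: a single linear, the only op this quantizer annotates so far.
model = torch.nn.Linear(64, 32).eval()
example_inputs = (torch.randn(1, 64),)

quantizer = VulkanQuantizer().set_global(get_static_quantization_config())

exported = capture_pre_autograd_graph(model, example_inputs)
prepared = prepare_pt2e(exported, quantizer)
prepared(*example_inputs)  # one forward pass; the weight observers record min/max
quantized = convert_pt2e(prepared)

The converted graph then carries quantize/dequantize ops on the linear weights, which the Vulkan backend can lower toward weight_int8packed_mm.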

File tree

2 files changed: +135 -0 lines changed


backends/vulkan/quantizer/TARGETS

Lines changed: 13 additions & 0 deletions
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")

oncall("executorch")

python_library(
    name = "vulkan_quantizer",
    srcs = [
        "vulkan_quantizer.py",
    ],
    deps = [
        "//caffe2:torch",
    ],
)
backends/vulkan/quantizer/vulkan_quantizer.py

Lines changed: 122 additions & 0 deletions
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from __future__ import annotations

import functools
from typing import Any, Callable, Dict, Optional

import torch
from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver
from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
    _convert_scalars_to_attrs,
    OP_TO_ANNOTATOR,
    propagate_annotation,
    QuantizationConfig,
)
from torch.fx import Node


__all__ = [
    "VulkanQuantizer",
    "get_static_quantization_config",
]


@functools.lru_cache
def get_static_quantization_config(
    is_per_channel: bool = True,
    weight_qmin: int = -127,
    weight_qmax: int = 127,
) -> QuantizationConfig:
    weight_qscheme = (
        torch.per_channel_symmetric if is_per_channel else torch.per_tensor_symmetric
    )
    weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
        PerChannelMinMaxObserver if is_per_channel else MinMaxObserver
    )

    extra_args: Dict[str, Any] = {"eps": 2**-12}
    weight_quantization_spec = QuantizationSpec(
        dtype=torch.int8,
        quant_min=weight_qmin,
        quant_max=weight_qmax,
        qscheme=weight_qscheme,
        ch_axis=0,
        is_dynamic=False,
        observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args(
            **extra_args
        ),
    )

    # Only weight-only static quant is supported for now.
    quantization_config = QuantizationConfig(
        input_activation=None,
        output_activation=None,
        weight=weight_quantization_spec,
        bias=None,
        is_qat=False,
    )
    return quantization_config


_SUPPORTED_OPS = [
    "linear",
]


class VulkanQuantizer(Quantizer):

    def __init__(self) -> None:
        super().__init__()
        self.global_config: Optional[QuantizationConfig] = None

    def set_global(self, quantization_config: QuantizationConfig) -> VulkanQuantizer:
        self.global_config = quantization_config
        return self

    def transform_for_annotation(
        self, model: torch.fx.GraphModule
    ) -> torch.fx.GraphModule:
        """Transforms scalar values to tensor attributes"""
        return _convert_scalars_to_attrs(model)

    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        # Currently we only support static quant on Vulkan.
        model = self._annotate_for_static_quantization_config(model)
        propagate_annotation(model)
        return model

    def _annotate_all_static_patterns(
        self,
        model: torch.fx.GraphModule,
        quantization_config: Optional[QuantizationConfig],
        filter_fn: Optional[Callable[[Node], bool]] = None,
    ) -> torch.fx.GraphModule:
        if quantization_config is None:
            return model

        for op in _SUPPORTED_OPS:
            OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn)
        return model

    def _annotate_for_static_quantization_config(
        self, model: torch.fx.GraphModule
    ) -> torch.fx.GraphModule:
        self._annotate_all_static_patterns(
            model,
            self.global_config,
        )
        return model

    def validate(self, model: torch.fx.GraphModule) -> None:
        pass
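
For intuition, here is a hedged sketch of the per-channel symmetric scheme the spec above configures (torch.int8, quant_min=-127, quant_max=127, ch_axis=0, eps=2**-12). It mimics what PerChannelMinMaxObserver derives for a linear weight; it is illustrative, not the observer's exact code:

import torch

# Illustrative only: symmetric per-channel int8 quantization of a linear weight.
weight = torch.randn(32, 64)  # [out_features, in_features]; ch_axis=0

# Symmetric scheme: scale is the max |w| per output channel divided by qmax
# (127), clamped below by eps so an all-zero channel cannot yield scale == 0.
max_abs = weight.abs().amax(dim=1)
scale = torch.clamp(max_abs / 127.0, min=2**-12)

# Quantize, round, and clamp to [-127, 127]; zero_point is 0 for symmetric.
q_weight = torch.clamp((weight / scale.unsqueeze(1)).round(), -127, 127).to(torch.int8)

# What a weight-only kernel effectively computes with: dequantized weights.
dq_weight = q_weight.to(torch.float32) * scale.unsqueeze(1)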
