
Commit caadd81

VulkanQuantizer for weight-only quantization on linear
Differential Revision: D61243540
Pull Request resolved: #4707
1 parent 35da5bf commit caadd81
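For orientation: "weight-only" here means that only the weights of linear layers carry a QuantizationSpec (int8, symmetric, per-channel along the output axis by default), while input/output activations and bias stay unquantized. Below is a simplified, illustrative sketch of the arithmetic implied by the observer settings added in this commit (qmin=-128, qmax=127, ch_axis=0, eps=2**-12, zero_point fixed at 0 for symmetric schemes); it is not code from the commit, and the real observers in torch.ao.quantization handle more edge cases.

import torch

def fake_quantize_linear_weight(w: torch.Tensor) -> torch.Tensor:
    # Per-output-channel symmetric int8 fake-quantization of a
    # [out_features, in_features] linear weight, mirroring the config below.
    qmin, qmax, eps = -128, 127, 2**-12
    max_abs = w.abs().amax(dim=1, keepdim=True)              # one range per output channel
    scale = torch.clamp(max_abs / ((qmax - qmin) / 2), min=eps)
    q = torch.clamp(torch.round(w / scale), qmin, qmax)      # values in the int8 range
    return q * scale                                          # dequantized weight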

File tree

2 files changed: +133 -0 lines changed


backends/vulkan/quantizer/TARGETS

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")

oncall("executorch")

python_library(
    name = "vulkan_quantizer",
    srcs = [
        "vulkan_quantizer.py",
    ],
    deps = [
        "//caffe2:torch",
    ],
)
backends/vulkan/quantizer/vulkan_quantizer.py

Lines changed: 120 additions & 0 deletions
@@ -0,0 +1,120 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from __future__ import annotations

import functools
from typing import Any, Callable, Dict, Optional

import torch
from torch.ao.quantization.observer import MinMaxObserver, PerChannelMinMaxObserver
from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
    _convert_scalars_to_attrs,
    OP_TO_ANNOTATOR,
    propagate_annotation,
    QuantizationConfig,
)
from torch.fx import Node


__all__ = [
    "VulkanQuantizer",
    "get_weight_quantization_config",
]


@functools.lru_cache
def get_weight_quantization_config(
    is_per_channel: bool = True,
    weight_qmin: int = -128,
    weight_qmax: int = 127,
) -> QuantizationConfig:

    weight_qscheme = (
        torch.per_channel_symmetric if is_per_channel else torch.per_tensor_symmetric
    )
    weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
        PerChannelMinMaxObserver if is_per_channel else MinMaxObserver
    )
    extra_args: Dict[str, Any] = {"eps": 2**-12}

    weight_quantization_spec = QuantizationSpec(
        dtype=torch.int8,
        quant_min=weight_qmin,
        quant_max=weight_qmax,
        qscheme=weight_qscheme,
        ch_axis=0,
        is_dynamic=False,
        observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr.with_args(
            **extra_args
        ),
    )

    quantization_config = QuantizationConfig(
        input_activation=None,
        output_activation=None,
        weight=weight_quantization_spec,
        bias=None,
        is_qat=False,
    )
    return quantization_config


_SUPPORTED_OPS = [
    "linear",
]


class VulkanQuantizer(Quantizer):

    def __init__(self) -> None:
        super().__init__()
        self.global_config: Optional[QuantizationConfig] = None

    def set_global(self, quantization_config: QuantizationConfig) -> VulkanQuantizer:
        self.global_config = quantization_config
        return self

    def transform_for_annotation(
        self, model: torch.fx.GraphModule
    ) -> torch.fx.GraphModule:
        """Transforms scalar values to tensor attributes"""
        return _convert_scalars_to_attrs(model)

    def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
        # currently only support static quant on Vulkan
        model = self._annotate_for_static_quantization_config(model)
        propagate_annotation(model)
        return model

    def _annotate_all_static_patterns(
        self,
        model: torch.fx.GraphModule,
        quantization_config: Optional[QuantizationConfig],
        filter_fn: Optional[Callable[[Node], bool]] = None,
    ) -> torch.fx.GraphModule:
        if quantization_config is None:
            return model

        for op in _SUPPORTED_OPS:
            OP_TO_ANNOTATOR[op](model, quantization_config, filter_fn)
        return model

    def _annotate_for_static_quantization_config(
        self, model: torch.fx.GraphModule
    ) -> torch.fx.GraphModule:
        self._annotate_all_static_patterns(
            model,
            self.global_config,
        )
        return model

    def validate(self, model: torch.fx.GraphModule) -> None:
        pass
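
For context, a minimal sketch of how this quantizer would be driven through the standard PT2E flow. The import path (executorch.backends.vulkan.quantizer.vulkan_quantizer) and the export step are assumptions not shown in this diff; older PyTorch releases used capture_pre_autograd_graph instead of torch.export.export(...).module().

import torch
from executorch.backends.vulkan.quantizer.vulkan_quantizer import (  # assumed import path
    VulkanQuantizer,
    get_weight_quantization_config,
)
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e


class TinyLinear(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fc = torch.nn.Linear(64, 32)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(x)


model = TinyLinear().eval()
example_inputs = (torch.randn(1, 64),)

# Export to an FX graph, annotate linear weights via the Vulkan quantizer,
# run one calibration pass (weights only, so a single pass suffices), then convert.
exported = torch.export.export(model, example_inputs).module()
quantizer = VulkanQuantizer().set_global(get_weight_quantization_config(is_per_channel=True))
prepared = prepare_pt2e(exported, quantizer)
prepared(*example_inputs)
quantized = convert_pt2e(prepared)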
