Commit c2aa614

Migrate xnnpack/vulkan/boltnn pt2e from torch.ao to torchao
Differential Revision: D75492104
Pull Request resolved: #11363
1 parent 00ab9f6 commit c2aa614
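
The migration is mechanical: quantization entry points that previously lived under torch.ao.quantization now come from torchao.quantization.pt2e. A representative before/after, assembled from the hunks below:

```python
# Before (removed in this commit):
# from torch.ao.quantization.observer import PerChannelMinMaxObserver
# from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer

# After: the same names imported from torchao
from torchao.quantization.pt2e import PerChannelMinMaxObserver
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
from torchao.quantization.pt2e.quantizer import (
    QuantizationConfig,
    QuantizationSpec,
    Quantizer,
)
```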

15 files changed: +120 −217 lines

.lintrunner.toml

Lines changed: 1 addition & 7 deletions

@@ -386,15 +386,9 @@ exclude_patterns = [
     "third-party/**",
     # TODO: remove exceptions as we migrate
     # backends
-    "backends/vulkan/quantizer/**",
-    "backends/vulkan/test/**",
-    "backends/xnnpack/quantizer/**",
-    "backends/xnnpack/test/**",
-    "exir/tests/test_passes.py",
-    "extension/llm/export/builder.py",
-    "extension/llm/export/quantizer_lib.py",
     "exir/tests/test_memory_planning.py",
     "exir/backend/test/demos/test_xnnpack_qnnpack.py",
+    "backends/xnnpack/test/test_xnnpack_utils.py",
 ]
 
 command = [

backends/vulkan/quantizer/vulkan_quantizer.py

Lines changed: 6 additions & 3 deletions

@@ -16,11 +16,14 @@
     _convert_scalars_to_attrs,
     OP_TO_ANNOTATOR,
     propagate_annotation,
-    QuantizationConfig,
 )
-from torch.ao.quantization.observer import PerChannelMinMaxObserver
-from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer
 from torch.fx import Node
+from torchao.quantization.pt2e import PerChannelMinMaxObserver
+from torchao.quantization.pt2e.quantizer import (
+    QuantizationConfig,
+    QuantizationSpec,
+    Quantizer,
+)
 
 
 __all__ = [
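
For context, a minimal sketch (not taken from this file) of building a per-channel weight spec with the relocated imports, assuming torchao's QuantizationSpec and PerChannelMinMaxObserver keep the same interfaces as the torch.ao classes they replace:

```python
import torch

from torchao.quantization.pt2e import PerChannelMinMaxObserver
from torchao.quantization.pt2e.quantizer import QuantizationSpec

# Hypothetical 8-bit per-channel weight spec; field names follow the torch.ao
# QuantizationSpec this commit migrates away from.
weight_qspec = QuantizationSpec(
    dtype=torch.int8,
    quant_min=-127,
    quant_max=127,
    qscheme=torch.per_channel_symmetric,
    ch_axis=0,
    is_dynamic=False,
    observer_or_fake_quant_ctr=PerChannelMinMaxObserver.with_args(eps=2**-12),
)
```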

backends/vulkan/test/test_vulkan_delegate.py

Lines changed: 2 additions & 2 deletions

@@ -23,12 +23,12 @@
     EdgeProgramManager,
     ExecutorchProgramManager,
 )
-
-from torch.ao.quantization.quantizer import Quantizer
 from torch.export import Dim, export, export_for_training, ExportedProgram
 
 from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
 
+from torchao.quantization.pt2e.quantizer import Quantizer
+
 ctypes.CDLL("libvulkan.so.1")
 
 
backends/vulkan/test/test_vulkan_passes.py

Lines changed: 1 addition & 1 deletion

@@ -16,9 +16,9 @@
 from executorch.exir.backend.canonical_partitioners.config_partitioner import (
     format_target_name,
 )
-from torch.ao.quantization.quantizer import Quantizer
 
 from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
+from torchao.quantization.pt2e.quantizer import Quantizer
 
 ###################
 ## Common Models ##

backends/xnnpack/partition/config/quant_affine_configs.py

Lines changed: 6 additions & 15 deletions

@@ -33,33 +33,24 @@ class QuantizeAffineConfig(QDQAffineConfigs):
     target_name = "quantize_affine.default"
 
     def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
-        try:
-            import torchao.quantization.quant_primitives  # noqa
+        import torchao.quantization.quant_primitives  # noqa
 
-            return torch.ops.torchao.quantize_affine.default
-        except:
-            return None
+        return torch.ops.torchao.quantize_affine.default
 
 
 class DeQuantizeAffineConfig(QDQAffineConfigs):
     target_name = "dequantize_affine.default"
 
     def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
-        try:
-            import torchao.quantization.quant_primitives  # noqa
+        import torchao.quantization.quant_primitives  # noqa
 
-            return torch.ops.torchao.dequantize_affine.default
-        except:
-            return None
+        return torch.ops.torchao.dequantize_affine.default
 
 
 class ChooseQParamsAffineConfig(QDQAffineConfigs):
     target_name = "choose_qparams_affine.default"
 
     def get_original_aten(self) -> Optional[torch._ops.OpOverload]:
-        try:
-            import torchao.quantization.quant_primitives  # noqa
+        import torchao.quantization.quant_primitives  # noqa
 
-            return torch.ops.torchao.choose_qparams_affine.default
-        except:
-            return None
+        return torch.ops.torchao.choose_qparams_affine.default
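
The try/except fallback is dropped, presumably because torchao is now an unconditional dependency: importing torchao.quantization.quant_primitives registers the affine ops that get_original_aten() returns. An illustrative check (not part of the commit):

```python
import torch
import torchao.quantization.quant_primitives  # noqa: F401  # registers torch.ops.torchao.*

# These lookups are what get_original_aten() now returns unconditionally.
print(torch.ops.torchao.quantize_affine.default)
print(torch.ops.torchao.dequantize_affine.default)
print(torch.ops.torchao.choose_qparams_affine.default)
```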

backends/xnnpack/quantizer/xnnpack_quantizer.py

Lines changed: 13 additions & 12 deletions

@@ -12,30 +12,31 @@
 from executorch.backends.xnnpack.quantizer.xnnpack_quantizer_utils import (
     _convert_scalars_to_attrs,
     OP_TO_ANNOTATOR,
-    OperatorConfig,
-    OperatorPatternType,
     propagate_annotation,
-    QuantizationConfig,
 )
-from torch.ao.quantization.fake_quantize import (
+from torchao.quantization.pt2e import (
     FakeQuantize,
     FusedMovingAvgObsFakeQuantize,
-)
-from torch.ao.quantization.observer import (
     HistogramObserver,
     MinMaxObserver,
     MovingAverageMinMaxObserver,
     MovingAveragePerChannelMinMaxObserver,
     PerChannelMinMaxObserver,
     PlaceholderObserver,
 )
-from torch.ao.quantization.quantizer import QuantizationSpec, Quantizer
-from torch.ao.quantization.quantizer.utils import _get_module_name_filter
+from torchao.quantization.pt2e.quantizer import (
+    get_module_name_filter,
+    OperatorConfig,
+    OperatorPatternType,
+    QuantizationConfig,
+    QuantizationSpec,
+    Quantizer,
+)
 
 
 if TYPE_CHECKING:
-    from torch.ao.quantization.qconfig import _ObserverOrFakeQuantizeConstructor
     from torch.fx import Node
+    from torchao.quantization.pt2e import ObserverOrFakeQuantizeConstructor
 
 
 __all__ = [

@@ -140,7 +141,7 @@ def get_symmetric_quantization_config(
     weight_qscheme = (
         torch.per_channel_symmetric if is_per_channel else torch.per_tensor_symmetric
     )
-    weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
+    weight_observer_or_fake_quant_ctr: ObserverOrFakeQuantizeConstructor = (
         MinMaxObserver
     )
     if is_qat:

@@ -228,7 +229,7 @@ def _get_not_module_type_or_name_filter(
     tp_list: list[Callable], module_name_list: list[str]
 ) -> Callable[[Node], bool]:
     module_type_filters = [_get_module_type_filter(tp) for tp in tp_list]
-    module_name_list_filters = [_get_module_name_filter(m) for m in module_name_list]
+    module_name_list_filters = [get_module_name_filter(m) for m in module_name_list]
 
     def not_module_type_or_name_filter(n: Node) -> bool:
         return not any(f(n) for f in module_type_filters + module_name_list_filters)

@@ -421,7 +422,7 @@ def _annotate_for_quantization_config(
         module_name_list = list(self.module_name_config.keys())
         for module_name, config in self.module_name_config.items():
             self._annotate_all_patterns(
-                model, config, _get_module_name_filter(module_name)
+                model, config, get_module_name_filter(module_name)
             )
 
         tp_list = list(self.module_type_config.keys())
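
The public entry points of this module are unchanged by the migration; only the observer, spec, and filter types now come from torchao. A usage sketch (not from this commit), assuming the existing XNNPACKQuantizer API:

```python
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

# Configure symmetric per-channel quantization globally; the resulting quantizer
# is then handed to prepare_pt2e() as in the PT2E flow sketched earlier.
quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config(is_per_channel=True))
```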

backends/xnnpack/quantizer/xnnpack_quantizer_utils.py

Lines changed: 27 additions & 110 deletions

@@ -1,39 +1,43 @@
 # mypy: allow-untyped-defs
 import itertools
-import typing
-from dataclasses import dataclass
-from typing import Callable, NamedTuple, Optional
+from typing import Callable, Optional
 
 import torch
 import torch.nn.functional as F
 from executorch.backends.xnnpack.utils.utils import is_depthwise_conv
 from torch._subclasses import FakeTensor
-from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix
-from torch.ao.quantization.pt2e.export_utils import _WrapperModule
-from torch.ao.quantization.pt2e.utils import (
-    _get_aten_graph_module_for_pattern,
-    _is_conv_node,
-    _is_conv_transpose_node,
+from torch.fx import Node
+from torch.fx.passes.utils.matcher_with_name_node_map_utils import (
+    SubgraphMatcherWithNameNodeMap,
 )
-from torch.ao.quantization.quantizer import (
+from torchao.quantization.pt2e import WrapperModule
+from torchao.quantization.pt2e.graph_utils import get_source_partitions
+from torchao.quantization.pt2e.quantizer import (
+    annotate_input_qspec_map,
+    annotate_output_qspec,
+    get_bias_qspec,
+    get_input_act_qspec,
+    get_output_act_qspec,
+    get_weight_qspec,
+    OperatorConfig,
+    OperatorPatternType,
     QuantizationAnnotation,
+    QuantizationConfig,
     QuantizationSpec,
     SharedQuantizationSpec,
 )
-from torch.ao.quantization.quantizer.utils import (
-    _annotate_input_qspec_map,
-    _annotate_output_qspec,
-)
-from torch.fx import Node
-from torch.fx.passes.utils.matcher_with_name_node_map_utils import (
-    SubgraphMatcherWithNameNodeMap,
+from torchao.quantization.pt2e.utils import (
+    _get_aten_graph_module_for_pattern,
+    _is_conv_node,
+    _is_conv_transpose_node,
+    get_new_attr_name_with_prefix,
 )
-from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
 
 __all__ = [
     "OperatorConfig",
     "OperatorPatternType",
     "QuantizationConfig",
+    "QuantizationSpec",
     "get_input_act_qspec",
     "get_output_act_qspec",
     "get_weight_qspec",

@@ -43,23 +47,6 @@
 ]
 
 
-# In the absence of better name, just winging it with QuantizationConfig
-@dataclass(eq=True, frozen=True)
-class QuantizationConfig:
-    input_activation: Optional[QuantizationSpec]
-    output_activation: Optional[QuantizationSpec]
-    weight: Optional[QuantizationSpec]
-    bias: Optional[QuantizationSpec]
-    # TODO: remove, since we can use observer_or_fake_quant_ctr to express this
-    is_qat: bool = False
-
-
-# Use Annotated because list[Callable].__module__ is read-only.
-OperatorPatternType = typing.Annotated[list[Callable], None]
-OperatorPatternType.__module__ = (
-    "executorch.backends.xnnpack.quantizer.xnnpack_quantizer_utils"
-)
-
 AnnotatorType = Callable[
     [
         torch.fx.GraphModule,

@@ -78,19 +65,6 @@ def decorator(annotator: AnnotatorType) -> None:
     return decorator
 
 
-class OperatorConfig(NamedTuple):
-    # fix List[str] with List[List[Union[nn.Module, FunctionType, BuiltinFunctionType]]]
-    # Basically we are mapping a quantization config to some list of patterns.
-    # a pattern is defined as a list of nn module, function or builtin function names
-    # e.g. [nn.Conv2d, torch.relu, torch.add]
-    # We have not resolved whether fusion can be considered internal details of the
-    # quantizer hence it does not need communication to user.
-    # Note this pattern is not really informative since it does not really
-    # tell us the graph structure resulting from the list of ops.
-    config: QuantizationConfig
-    operators: list[OperatorPatternType]
-
-
 def is_relu_node(node: Node) -> bool:
     """
     Check if a given node is a relu node

@@ -124,63 +98,6 @@ def _mark_nodes_as_annotated(nodes: list[Node]):
             node.meta["quantization_annotation"]._annotated = True
 
 
-def get_input_act_qspec(quantization_config: Optional[QuantizationConfig]):
-    if quantization_config is None:
-        return None
-    if quantization_config.input_activation is None:
-        return None
-    quantization_spec: QuantizationSpec = quantization_config.input_activation
-    assert quantization_spec.qscheme in [
-        torch.per_tensor_affine,
-        torch.per_tensor_symmetric,
-    ]
-    return quantization_spec
-
-
-def get_output_act_qspec(quantization_config: Optional[QuantizationConfig]):
-    if quantization_config is None:
-        return None
-    if quantization_config.output_activation is None:
-        return None
-    quantization_spec: QuantizationSpec = quantization_config.output_activation
-    assert quantization_spec.qscheme in [
-        torch.per_tensor_affine,
-        torch.per_tensor_symmetric,
-    ]
-    return quantization_spec
-
-
-def get_weight_qspec(quantization_config: Optional[QuantizationConfig]):
-    if quantization_config is None:
-        return None
-    assert quantization_config is not None
-    if quantization_config.weight is None:
-        return None
-    quantization_spec: QuantizationSpec = quantization_config.weight
-    if quantization_spec.qscheme not in [
-        torch.per_tensor_symmetric,
-        torch.per_channel_symmetric,
-        None,
-    ]:
-        raise ValueError(
-            f"Unsupported quantization_spec {quantization_spec} for weight"
-        )
-    return quantization_spec
-
-
-def get_bias_qspec(quantization_config: Optional[QuantizationConfig]):
-    if quantization_config is None:
-        return None
-    assert quantization_config is not None
-    if quantization_config.bias is None:
-        return None
-    quantization_spec: QuantizationSpec = quantization_config.bias
-    assert (
-        quantization_spec.dtype == torch.float
-    ), "Only float dtype for bias is supported for bias right now"
-    return quantization_spec
-
-
 @register_annotator("linear")
 def _annotate_linear(
     gm: torch.fx.GraphModule,

@@ -204,25 +121,25 @@ def _annotate_linear(
             bias_node = node.args[2]
 
         if _is_annotated([node]) is False:  # type: ignore[list-item]
-            _annotate_input_qspec_map(
+            annotate_input_qspec_map(
                 node,
                 act_node,
                 input_act_qspec,
            )
-            _annotate_input_qspec_map(
+            annotate_input_qspec_map(
                 node,
                 weight_node,
                 weight_qspec,
            )
             nodes_to_mark_annotated = [node, weight_node]
             if bias_node:
-                _annotate_input_qspec_map(
+                annotate_input_qspec_map(
                     node,
                     bias_node,
                     bias_qspec,
                )
                 nodes_to_mark_annotated.append(bias_node)
-            _annotate_output_qspec(node, output_act_qspec)
+            annotate_output_qspec(node, output_act_qspec)
             _mark_nodes_as_annotated(nodes_to_mark_annotated)
             annotated_partitions.append(nodes_to_mark_annotated)
 

@@ -572,7 +489,7 @@ def _conv_bn(x, conv_weight, conv_bias, bn_weight, bn_bias, bn_rm, bn_rv):
                 "output": output,
             }
 
-        return _WrapperModule(_conv_bn)
+        return WrapperModule(_conv_bn)
 
     # Needed for matching, otherwise the matches gets filtered out due to unused
    # nodes returned by batch norm
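
The locally defined QuantizationConfig, OperatorConfig, OperatorPatternType, and the get_*_qspec helpers are deleted in favor of the torchao versions. A sketch of consuming the relocated helpers, assuming torchao's QuantizationConfig mirrors the deleted dataclass (input_activation, output_activation, weight, and bias fields):

```python
from torchao.quantization.pt2e.quantizer import (
    QuantizationConfig,
    get_bias_qspec,
    get_input_act_qspec,
    get_output_act_qspec,
    get_weight_qspec,
)


def describe(config: QuantizationConfig) -> None:
    # Each helper accepts an Optional[QuantizationConfig] and returns the matching
    # QuantizationSpec (or None), as the deleted local copies did.
    for name, qspec in (
        ("input", get_input_act_qspec(config)),
        ("output", get_output_act_qspec(config)),
        ("weight", get_weight_qspec(config)),
        ("bias", get_bias_qspec(config)),
    ):
        print(name, qspec)
```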
