
Commit 2f400ca ("up")
1 parent: e21fa8d

File tree: 12 files changed (+70, -73 lines)


backends/arm/quantizer/arm_quantizer.py

Lines changed: 10 additions & 11 deletions
@@ -31,23 +31,22 @@
 ) # usort: skip
 from executorch.exir.backend.compile_spec_schema import CompileSpec
 from torch.fx import GraphModule, Node
-from torchao.quantization.pt2e import _ObserverOrFakeQuantizeConstructor
-from torchao.quantization.pt2e.fake_quantize import (
+from torchao.quantization.pt2e import (
     FakeQuantize,
     FusedMovingAvgObsFakeQuantize,
-)
-from torchao.quantization.pt2e.observer import (
     HistogramObserver,
     MinMaxObserver,
     MovingAverageMinMaxObserver,
     MovingAveragePerChannelMinMaxObserver,
+    ObserverOrFakeQuantizeConstructor,
     PerChannelMinMaxObserver,
     PlaceholderObserver,
 )
-from torchao.quantization.pt2e.quantizer import QuantizationSpec, Quantizer
-from torchao.quantization.pt2e.quantizer.utils import (
-    _annotate_input_qspec_map,
-    _annotate_output_qspec,
+from torchao.quantization.pt2e.quantizer import (
+    annotate_input_qspec_map,
+    annotate_output_qspec,
+    QuantizationSpec,
+    Quantizer,
 )

 __all__ = [

@@ -97,7 +96,7 @@ def get_symmetric_quantization_config(
     weight_qscheme = (
         torch.per_channel_symmetric if is_per_channel else torch.per_tensor_symmetric
     )
-    weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
+    weight_observer_or_fake_quant_ctr: ObserverOrFakeQuantizeConstructor = (
         MinMaxObserver
     )
     if is_qat:

@@ -337,14 +336,14 @@ def _annotate_io(
            if is_annotated(node):
                continue
            if node.op == "placeholder" and len(node.users) > 0:
-                _annotate_output_qspec(
+                annotate_output_qspec(
                    node,
                    quantization_config.get_output_act_qspec(),
                )
                mark_node_as_annotated(node)
            if node.op == "output":
                parent = node.all_input_nodes[0]
-                _annotate_input_qspec_map(
+                annotate_input_qspec_map(
                    node, parent, quantization_config.get_input_act_qspec()
                )
                mark_node_as_annotated(node)
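As a usage reference, here is a minimal sketch (not part of this commit) of the now-public helpers imported above. The function name annotate_io_sketch is hypothetical, and the quantization_config argument is assumed to expose get_input_act_qspec()/get_output_act_qspec(), as in _annotate_io:

    from torch.fx import GraphModule
    from torchao.quantization.pt2e.quantizer import (
        annotate_input_qspec_map,
        annotate_output_qspec,
    )


    def annotate_io_sketch(model: GraphModule, quantization_config) -> None:
        # Illustrative restatement of _annotate_io using the renamed public helpers.
        for node in model.graph.nodes:
            if node.op == "placeholder" and len(node.users) > 0:
                # Record the qspec to apply to the value this graph input produces.
                annotate_output_qspec(node, quantization_config.get_output_act_qspec())
            if node.op == "output":
                parent = node.all_input_nodes[0]
                # Record which qspec to use for the value flowing from `parent` into `node`.
                annotate_input_qspec_map(
                    node, parent, quantization_config.get_input_act_qspec()
                )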

backends/arm/quantizer/quantization_annotator.py

Lines changed: 4 additions & 6 deletions
@@ -14,13 +14,11 @@
 from executorch.backends.arm.tosa_utils import get_node_debug_info
 from torch.fx import Node
 from torchao.quantization.pt2e.quantizer import (
+    annotate_input_qspec_map,
+    annotate_output_qspec,
     QuantizationSpecBase,
     SharedQuantizationSpec,
 )
-from torchao.quantization.pt2e.quantizer.utils import (
-    _annotate_input_qspec_map,
-    _annotate_output_qspec,
-)

 from .arm_quantizer_utils import (
     is_annotated,

@@ -121,7 +119,7 @@ def _annotate_input(node: Node, quant_property: _QuantProperty):
        strict=True,
    ):
        assert isinstance(n_arg, Node)
-        _annotate_input_qspec_map(node, n_arg, qspec)
+        annotate_input_qspec_map(node, n_arg, qspec)
        if quant_property.mark_annotated:
            mark_node_as_annotated(n_arg)  # type: ignore[attr-defined]

@@ -132,7 +130,7 @@ def _annotate_output(node: Node, quant_property: _QuantProperty):
    assert not quant_property.optional
    assert quant_property.index == 0, "Only one output annotation supported currently"

-    _annotate_output_qspec(node, quant_property.qspec)
+    annotate_output_qspec(node, quant_property.qspec)


 def _match_pattern(

backends/example/example_backend_delegate_passes/permute_memory_formats_pass.py

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.dim_order_utils import get_dim_order
 from executorch.exir.pass_base import ExportPass, PassResult
-from torchao.quantization.pt2e.pt2e.graph_utils import find_sequential_partitions
+from torchao.quantization.pt2e import find_sequential_partitions


 class PermuteMemoryFormatsPass(ExportPass):

backends/example/example_partitioner.py

Lines changed: 1 addition & 1 deletion
@@ -21,7 +21,7 @@
 from executorch.exir.graph_module import get_control_flow_submodules
 from torch.export import ExportedProgram
 from torch.fx.passes.operator_support import OperatorSupportBase
-from torchao.quantization.pt2e.pt2e.graph_utils import find_sequential_partitions
+from torchao.quantization.pt2e import find_sequential_partitions


 @final
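For context, a rough usage sketch (not part of this commit) of find_sequential_partitions, which the example-backend files now import from the torchao.quantization.pt2e package root. The toy module and the export frontend below are illustrative assumptions:

    import torch
    from torchao.quantization.pt2e import find_sequential_partitions


    class ConvReLU(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 8, 3)
            self.relu = torch.nn.ReLU()

        def forward(self, x):
            return self.relu(self.conv(x))


    example_inputs = (torch.randn(1, 3, 16, 16),)
    gm = torch.export.export(ConvReLU(), example_inputs).module()
    # Each entry pairs a Conv2d source partition with the ReLU partition that consumes it.
    fused = find_sequential_partitions(gm, [torch.nn.Conv2d, torch.nn.ReLU])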

backends/example/example_quantizer.py

Lines changed: 1 addition & 1 deletion
@@ -11,8 +11,8 @@
 from executorch.backends.example.example_operators.ops import module_to_annotator
 from executorch.backends.xnnpack.quantizer.xnnpack_quantizer_utils import OperatorConfig
 from torch import fx
+from torchao.quantization.pt2e import find_sequential_partitions
 from torchao.quantization.pt2e.observer import HistogramObserver, MinMaxObserver
-from torchao.quantization.pt2e.pt2e.graph_utils import find_sequential_partitions
 from torchao.quantization.pt2e.quantizer import QuantizationSpec, Quantizer

backends/mediatek/quantizer/annotator.py

Lines changed: 7 additions & 7 deletions
@@ -16,10 +16,10 @@
     SubgraphMatcherWithNameNodeMap,
 )

-from torchao.quantization.pt2e.quantizer import QuantizationAnnotation
-from torchao.quantization.pt2e.quantizer.utils import (
-    _annotate_input_qspec_map,
-    _annotate_output_qspec,
+from torchao.quantization.pt2e.quantizer import (
+    annotate_input_qspec_map,
+    annotate_output_qspec as _annotate_output_qspec,
+    QuantizationAnnotation,
 )

 from .qconfig import QuantizationConfig

@@ -108,7 +108,7 @@ def _annotate_fused_activation_pattern(
        torch.ops.aten.linear.default,
    ]:
        weight_node = producer_node.args[1]
-        _annotate_input_qspec_map(
+        annotate_input_qspec_map(
            producer_node,
            weight_node,
            quant_config.weight,

@@ -201,7 +201,7 @@ def annotate_affine_ops(node: Node, quant_config: QuantizationConfig) -> None:
        return

    weight_node = node.args[1]
-    _annotate_input_qspec_map(
+    annotate_input_qspec_map(
        node,
        weight_node,
        quant_config.weight,

@@ -260,5 +260,5 @@ def annotate_embedding_op(node: Node, quant_config: QuantizationConfig) -> None:
        return

    wgt_node = node.args[0]
-    _annotate_input_qspec_map(node, wgt_node, quant_config.activation)
+    annotate_input_qspec_map(node, wgt_node, quant_config.activation)
    _mark_as_annotated([node])
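Note the aliased import in the first hunk above: binding the public helper to the old private name keeps any remaining _annotate_output_qspec(...) call sites in this file working without further edits. In isolation, the pattern is simply:

    # Public helper bound to the legacy private name, so existing call sites stay untouched.
    from torchao.quantization.pt2e.quantizer import (
        annotate_output_qspec as _annotate_output_qspec,
    )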

backends/qualcomm/quantizer/annotators.py

Lines changed: 23 additions & 23 deletions
@@ -23,8 +23,8 @@
     SharedQuantizationSpec,
 )
 from torchao.quantization.pt2e.quantizer.utils import (
-    _annotate_input_qspec_map,
-    _annotate_output_qspec,
+    annotate_input_qspec_map,
+    annotate_output_qspec,
 )

 from .qconfig import (

@@ -618,19 +618,19 @@ def annotate_rms_norm(node: Node, quantization_config: QuantizationConfig) -> No
        return

    # TODO current only support 16a16w
-    _annotate_input_qspec_map(
+    annotate_input_qspec_map(
        node,
        act_node,
        quantization_config.input_activation,
    )

-    _annotate_input_qspec_map(
+    annotate_input_qspec_map(
        node,
        weight_node,
        quantization_config.input_activation,
    )
    nodes_to_mark_annotated = [node]
-    _annotate_output_qspec(node, quantization_config.output_activation)
+    annotate_output_qspec(node, quantization_config.output_activation)
    _mark_nodes_as_annotated(nodes_to_mark_annotated)

@@ -819,25 +819,25 @@ def annotate_group_norm(node: Node, quantization_config: QuantizationConfig) ->
    if _is_annotated([node]):
        return

-    _annotate_input_qspec_map(
+    annotate_input_qspec_map(
        node,
        act_node,
        quantization_config.input_activation,
    )
-    _annotate_input_qspec_map(
+    annotate_input_qspec_map(
        node,
        weight_node,
        quantization_config.weight,
    )
    nodes_to_mark_annotated = [node, weight_node]
    if bias_node:
-        _annotate_input_qspec_map(
+        annotate_input_qspec_map(
            node,
            bias_node,
            quantization_config.bias,
        )
        nodes_to_mark_annotated.append(bias_node)
-    _annotate_output_qspec(node, quantization_config.output_activation)
+    annotate_output_qspec(node, quantization_config.output_activation)
    _mark_nodes_as_annotated(nodes_to_mark_annotated)

@@ -1002,12 +1002,12 @@ def annotate_linear(node: Node, quantization_config: QuantizationConfig) -> None
    if _is_annotated([node]):
        return

-    _annotate_input_qspec_map(
+    annotate_input_qspec_map(
        node,
        act_node,
        quantization_config.input_activation,
    )
-    _annotate_input_qspec_map(
+    annotate_input_qspec_map(
        node,
        weight_node,
        quantization_config.weight,

@@ -1018,9 +1018,9 @@ def annotate_linear(node: Node, quantization_config: QuantizationConfig) -> None
            bias_config = quantization_config.bias(node)
        else:
            bias_config = quantization_config.bias
-        _annotate_input_qspec_map(node, bias_node, bias_config)
+        annotate_input_qspec_map(node, bias_node, bias_config)
        nodes_to_mark_annotated.append(bias_node)
-    _annotate_output_qspec(node, quantization_config.output_activation)
+    annotate_output_qspec(node, quantization_config.output_activation)
    _mark_nodes_as_annotated(nodes_to_mark_annotated)

    # We use get_source_partition in pass, but it is the same source for MultiheadAttention, so we need to change its source_fn_stack.

@@ -1038,29 +1038,29 @@ def annotate_batch_and_instance_norm(
        return

    annotated_args = [act]
-    _annotate_input_qspec_map(
+    annotate_input_qspec_map(
        node,
        act,
        quantization_config.input_activation,
    )
    # QNN requires uint8 instead of int8 in 'weight' config
    if weight is not None:
-        _annotate_input_qspec_map(
+        annotate_input_qspec_map(
            node,
            weight,
            quantization_config.input_activation,
        )
        annotated_args.append(weight)

    if bias is not None:
-        _annotate_input_qspec_map(
+        annotate_input_qspec_map(
            node,
            bias,
            quantization_config.bias,
        )
        annotated_args.append(bias)

-    _annotate_output_qspec(node, quantization_config.output_activation)
+    annotate_output_qspec(node, quantization_config.output_activation)
    _mark_nodes_as_annotated([node, *annotated_args])

@@ -1070,7 +1070,7 @@ def annotate_getitem(node: Node, quantization_config: QuantizationConfig) -> Non
        return

    if _is_float_tensor(node):
-        _annotate_output_qspec(node, quantization_config.output_activation)
+        annotate_output_qspec(node, quantization_config.output_activation)
        _mark_nodes_as_annotated([node])

@@ -1086,32 +1086,32 @@ def annotate_layer_norm(node: Node, quantization_config: QuantizationConfig) ->
        return
    input_act_qspec = quantization_config.input_activation

-    _annotate_input_qspec_map(
+    annotate_input_qspec_map(
        node,
        act_node,
        input_act_qspec,
    )
    if input_act_qspec.dtype == torch.int32:
-        _annotate_input_qspec_map(
+        annotate_input_qspec_map(
            node,
            weight_node,
            get_16a16w_qnn_ptq_config().weight,
        )
    else:
-        _annotate_input_qspec_map(
+        annotate_input_qspec_map(
            node,
            weight_node,
            input_act_qspec,
        )
    nodes_to_mark_annotated = [node, weight_node]
    if bias_node:
-        _annotate_input_qspec_map(
+        annotate_input_qspec_map(
            node,
            bias_node,
            quantization_config.bias,
        )
        nodes_to_mark_annotated.append(bias_node)
-    _annotate_output_qspec(node, quantization_config.output_activation)
+    annotate_output_qspec(node, quantization_config.output_activation)
    _mark_nodes_as_annotated(nodes_to_mark_annotated)

backends/qualcomm/quantizer/observers/per_block_param_observer.py

Lines changed: 2 additions & 2 deletions
@@ -7,12 +7,12 @@
 from typing import Tuple

 import torch
-from torchao.quantization.pt2e.observer import MappingType, PerBlock
-from torchao.quantization.pt2e.pt2e._affine_quantization import (
+from torchao.quantization.pt2e._affine_quantization import (
     _get_reduction_params,
     AffineQuantizedMinMaxObserver,
     choose_qparams_affine_with_min_max,
 )
+from torchao.quantization.pt2e.observer import MappingType, PerBlock


 class PerBlockParamObserver(AffineQuantizedMinMaxObserver):

backends/transforms/duplicate_dynamic_quant_chain.py

Lines changed: 2 additions & 2 deletions
@@ -12,9 +12,9 @@
 from torch.fx.node import map_arg
 from torch.fx.passes.infra.pass_base import PassBase, PassResult

-from torchao.quantization.pt2e.pt2e.utils import (
+from torchao.quantization.pt2e.utils import (
     _filter_sym_size_users,
-    _is_valid_annotation,
+    _is_valid_annotation,  # @nocommit not found
 )

backends/xnnpack/quantizer/xnnpack_quantizer.py

Lines changed: 5 additions & 5 deletions
@@ -30,12 +30,12 @@
     PlaceholderObserver,
 )
 from torchao.quantization.pt2e.quantizer import QuantizationSpec, Quantizer
-from torchao.quantization.pt2e.quantizer.utils import _get_module_name_filter
+from torchao.quantization.pt2e.quantizer.utils import get_module_name_filter


 if TYPE_CHECKING:
     from torch.fx import Node
-    from torchao.quantization.pt2e import _ObserverOrFakeQuantizeConstructor
+    from torchao.quantization.pt2e import ObserverOrFakeQuantizeConstructor


 __all__ = [

@@ -140,7 +140,7 @@ def get_symmetric_quantization_config(
    weight_qscheme = (
        torch.per_channel_symmetric if is_per_channel else torch.per_tensor_symmetric
    )
-    weight_observer_or_fake_quant_ctr: _ObserverOrFakeQuantizeConstructor = (
+    weight_observer_or_fake_quant_ctr: ObserverOrFakeQuantizeConstructor = (
        MinMaxObserver
    )
    if is_qat:

@@ -228,7 +228,7 @@ def _get_not_module_type_or_name_filter(
    tp_list: list[Callable], module_name_list: list[str]
 ) -> Callable[[Node], bool]:
    module_type_filters = [_get_module_type_filter(tp) for tp in tp_list]
-    module_name_list_filters = [_get_module_name_filter(m) for m in module_name_list]
+    module_name_list_filters = [get_module_name_filter(m) for m in module_name_list]

    def not_module_type_or_name_filter(n: Node) -> bool:
        return not any(f(n) for f in module_type_filters + module_name_list_filters)

@@ -421,7 +421,7 @@ def _annotate_for_quantization_config(
        module_name_list = list(self.module_name_config.keys())
        for module_name, config in self.module_name_config.items():
            self._annotate_all_patterns(
-                model, config, _get_module_name_filter(module_name)
+                model, config, get_module_name_filter(module_name)
            )

        tp_list = list(self.module_type_config.keys())