Commit 1904fa8

commit message: init
parent: 2bc6207

14 files changed: 60 additions, 64 deletions

.lintrunner.toml

Lines changed: 0 additions & 5 deletions

@@ -391,11 +391,6 @@ exclude_patterns = [
     "backends/vulkan/quantizer/**",
     "backends/vulkan/test/**",
     "backends/cadence/aot/quantizer/**",
-<<<<<<< HEAD
-=======
-    "backends/qualcomm/quantizer/**",
-    "examples/qualcomm/**",
->>>>>>> 362501568 (up)
     "backends/xnnpack/quantizer/**",
     "backends/xnnpack/test/**",
     "exir/tests/test_passes.py",

backends/qualcomm/_passes/qnn_pass_manager.py

Lines changed: 2 additions & 2 deletions

@@ -131,8 +131,8 @@ def get_to_edge_transform_passes(
     from executorch.backends.qualcomm._passes import utils
     from executorch.exir.dialects._ops import ops as exir_ops

-    utils.q_ops.add(exir_ops.edge.pt2e_quant.quantize_affine.default)
-    utils.dq_ops.add(exir_ops.edge.pt2e_quant.dequantize_affine.default)
+    utils.q_ops.add(exir_ops.edge.torchao.quantize_affine.default)
+    utils.dq_ops.add(exir_ops.edge.torchao.dequantize_affine.default)

     passes_job = (
         passes_job if passes_job is not None else get_capture_program_passes()

backends/qualcomm/builders/node_visitor.py

Lines changed: 2 additions & 2 deletions

@@ -254,8 +254,8 @@ def get_quant_encoding_conf(
         )
         # TODO: refactor this when target could be correctly detected
         per_block_encoding = {
-            exir_ops.edge.pt2e_quant.quantize_affine.default,
-            exir_ops.edge.pt2e_quant.dequantize_affine.default,
+            exir_ops.edge.torchao.quantize_affine.default,
+            exir_ops.edge.torchao.dequantize_affine.default,
         }
         if quant_attrs[QCOM_ENCODING] in per_block_encoding:
             return self.make_qnn_per_block_config(node, quant_attrs)

backends/qualcomm/partition/utils.py

Lines changed: 2 additions & 2 deletions

@@ -57,7 +57,7 @@ def get_skip_decomp_table() -> List[torch._ops.OperatorBase]:
         torch.ops.aten.upsample_bicubic2d.vec,
         # This request is ignored because it is in a blocklist. Refer to exir/program/_program.py
         torch.ops.aten.unbind.int,
-        torch.ops.pt2e_quant.quantize_affine.default,
-        torch.ops.pt2e_quant.dequantize_affine.default,
+        torch.ops.torchao.quantize_affine.default,
+        torch.ops.torchao.dequantize_affine.default,
     ]
     return do_not_decompose
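Note: the skip-decomposition entries above now reference the affine quantization custom ops that torchao registers under torch.ops.torchao. A minimal sketch of building such a table, assuming torchao is installed and that importing torchao.quantization is what registers these custom ops:

import torch
import torchao.quantization  # noqa: F401  # assumed to register torch.ops.torchao.* custom ops

# Ops that carry quantization parameters and must survive decomposition untouched.
do_not_decompose = [
    torch.ops.torchao.quantize_affine.default,
    torch.ops.torchao.dequantize_affine.default,
]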

backends/qualcomm/quantizer/annotators.py

Lines changed: 26 additions & 29 deletions

@@ -12,20 +12,17 @@
 from torch._ops import OpOverload

 from torch._subclasses import FakeTensor
-from torch.ao.quantization.fake_quantize import FixedQParamsFakeQuantize
+from torch.fx import Node

-from torch.ao.quantization.observer import FixedQParamsObserver
-from torch.ao.quantization.quantizer import (
+from torchao.quantization.pt2e import FixedQParamsFakeQuantize, FixedQParamsObserver
+from torchao.quantization.pt2e.quantizer import (
+    annotate_input_qspec_map,
+    annotate_output_qspec,
     DerivedQuantizationSpec,
     QuantizationAnnotation,
     QuantizationSpec,
     SharedQuantizationSpec,
 )
-from torch.ao.quantization.quantizer.utils import (
-    _annotate_input_qspec_map,
-    _annotate_output_qspec,
-)
-from torch.fx import Node

 from .qconfig import (
     get_16a16w_qnn_ptq_config,

@@ -618,19 +615,19 @@ def annotate_rms_norm(node: Node, quantization_config: QuantizationConfig) -> None:
         return

     # TODO current only support 16a16w
-    _annotate_input_qspec_map(
+    annotate_input_qspec_map(
         node,
         act_node,
         quantization_config.input_activation,
     )

-    _annotate_input_qspec_map(
+    annotate_input_qspec_map(
         node,
         weight_node,
         quantization_config.input_activation,
     )
     nodes_to_mark_annotated = [node]
-    _annotate_output_qspec(node, quantization_config.output_activation)
+    annotate_output_qspec(node, quantization_config.output_activation)
     _mark_nodes_as_annotated(nodes_to_mark_annotated)


@@ -819,25 +816,25 @@ def annotate_group_norm(node: Node, quantization_config: QuantizationConfig) -> None:
     if _is_annotated([node]):
         return

-    _annotate_input_qspec_map(
+    annotate_input_qspec_map(
         node,
         act_node,
         quantization_config.input_activation,
     )
-    _annotate_input_qspec_map(
+    annotate_input_qspec_map(
         node,
         weight_node,
         quantization_config.weight,
     )
     nodes_to_mark_annotated = [node, weight_node]
     if bias_node:
-        _annotate_input_qspec_map(
+        annotate_input_qspec_map(
             node,
             bias_node,
             quantization_config.bias,
         )
         nodes_to_mark_annotated.append(bias_node)
-    _annotate_output_qspec(node, quantization_config.output_activation)
+    annotate_output_qspec(node, quantization_config.output_activation)
     _mark_nodes_as_annotated(nodes_to_mark_annotated)


@@ -1002,12 +999,12 @@ def annotate_linear(node: Node, quantization_config: QuantizationConfig) -> None:
     if _is_annotated([node]):
         return

-    _annotate_input_qspec_map(
+    annotate_input_qspec_map(
         node,
         act_node,
         quantization_config.input_activation,
     )
-    _annotate_input_qspec_map(
+    annotate_input_qspec_map(
         node,
         weight_node,
         quantization_config.weight,

@@ -1018,9 +1015,9 @@ def annotate_linear(node: Node, quantization_config: QuantizationConfig) -> None:
             bias_config = quantization_config.bias(node)
         else:
             bias_config = quantization_config.bias
-        _annotate_input_qspec_map(node, bias_node, bias_config)
+        annotate_input_qspec_map(node, bias_node, bias_config)
         nodes_to_mark_annotated.append(bias_node)
-    _annotate_output_qspec(node, quantization_config.output_activation)
+    annotate_output_qspec(node, quantization_config.output_activation)
     _mark_nodes_as_annotated(nodes_to_mark_annotated)

 # We use get_source_partition in pass, but it is the same source for MultiheadAttention, so we need to change its source_fn_stack.

@@ -1038,29 +1035,29 @@ def annotate_batch_and_instance_norm(
         return

     annotated_args = [act]
-    _annotate_input_qspec_map(
+    annotate_input_qspec_map(
         node,
         act,
         quantization_config.input_activation,
     )
     # QNN requires uint8 instead of int8 in 'weight' config
     if weight is not None:
-        _annotate_input_qspec_map(
+        annotate_input_qspec_map(
             node,
             weight,
             quantization_config.input_activation,
         )
         annotated_args.append(weight)

     if bias is not None:
-        _annotate_input_qspec_map(
+        annotate_input_qspec_map(
             node,
             bias,
             quantization_config.bias,
         )
         annotated_args.append(bias)

-    _annotate_output_qspec(node, quantization_config.output_activation)
+    annotate_output_qspec(node, quantization_config.output_activation)
     _mark_nodes_as_annotated([node, *annotated_args])


@@ -1070,7 +1067,7 @@ def annotate_getitem(node: Node, quantization_config: QuantizationConfig) -> None:
         return

     if _is_float_tensor(node):
-        _annotate_output_qspec(node, quantization_config.output_activation)
+        annotate_output_qspec(node, quantization_config.output_activation)
         _mark_nodes_as_annotated([node])


@@ -1086,32 +1083,32 @@ def annotate_layer_norm(node: Node, quantization_config: QuantizationConfig) -> None:
         return
     input_act_qspec = quantization_config.input_activation

-    _annotate_input_qspec_map(
+    annotate_input_qspec_map(
         node,
         act_node,
         input_act_qspec,
     )
     if input_act_qspec.dtype == torch.int32:
-        _annotate_input_qspec_map(
+        annotate_input_qspec_map(
             node,
             weight_node,
             get_16a16w_qnn_ptq_config().weight,
         )
     else:
-        _annotate_input_qspec_map(
+        annotate_input_qspec_map(
             node,
             weight_node,
             input_act_qspec,
         )
     nodes_to_mark_annotated = [node, weight_node]
     if bias_node:
-        _annotate_input_qspec_map(
+        annotate_input_qspec_map(
             node,
             bias_node,
             quantization_config.bias,
         )
         nodes_to_mark_annotated.append(bias_node)
-    _annotate_output_qspec(node, quantization_config.output_activation)
+    annotate_output_qspec(node, quantization_config.output_activation)
     _mark_nodes_as_annotated(nodes_to_mark_annotated)
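Note: the annotators now call the public annotate_input_qspec_map / annotate_output_qspec helpers exported by torchao.quantization.pt2e.quantizer instead of the underscored torch.ao utilities. A minimal sketch of the calling pattern under the new imports (the unary-op annotator below is hypothetical and not part of this commit):

import torch
from torch.fx import Node
from torchao.quantization.pt2e.quantizer import (
    annotate_input_qspec_map,
    annotate_output_qspec,
    QuantizationSpec,
)


def annotate_unary_op(node: Node, input_qspec: QuantizationSpec, output_qspec: QuantizationSpec) -> None:
    # Attach a qspec to the op's first input, then tag the op's output.
    annotate_input_qspec_map(node, node.args[0], input_qspec)
    annotate_output_qspec(node, output_qspec)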

backends/qualcomm/quantizer/custom_annotation.py

Lines changed: 3 additions & 3 deletions

@@ -17,13 +17,13 @@
     QuantizationConfig,
 )
 from executorch.exir.dialects._ops import ops as exir_ops
-from torch.ao.quantization.observer import FixedQParamsObserver, MinMaxObserver
-from torch.ao.quantization.quantizer import (
+from torch.fx import Node
+from torchao.quantization.pt2e import FixedQParamsObserver, MinMaxObserver
+from torchao.quantization.pt2e.quantizer import (
     QuantizationAnnotation,
     QuantizationSpec,
     SharedQuantizationSpec,
 )
-from torch.fx import Node


 def annotate_mimi_decoder(gm: torch.fx.GraphModule):
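Note: the observer and spec classes imported above now come from torchao. As a rough illustration of how they combine in a custom annotation (the fixed-qparams 16-bit spec below is illustrative, not one of the annotators in this file, and the FixedQParamsObserver keyword arguments are assumed to match the old torch.ao signature):

import torch
from torch.fx import Node
from torchao.quantization.pt2e import FixedQParamsObserver
from torchao.quantization.pt2e.quantizer import QuantizationAnnotation, QuantizationSpec


def annotate_fixed_output(node: Node) -> None:
    # Pin the node's output to a fixed 16-bit affine encoding.
    out_qspec = QuantizationSpec(
        dtype=torch.uint16,
        quant_min=0,
        quant_max=65535,
        qscheme=torch.per_tensor_affine,
        observer_or_fake_quant_ctr=FixedQParamsObserver.with_args(
            scale=1.0 / 65535.0,
            zero_point=0,
            dtype=torch.uint16,
            quant_min=0,
            quant_max=65535,
            qscheme=torch.per_tensor_affine,
        ),
    )
    node.meta["quantization_annotation"] = QuantizationAnnotation(
        output_qspec=out_qspec,
        _annotated=True,
    )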

backends/qualcomm/quantizer/observers/per_block_param_observer.py

Lines changed: 2 additions & 2 deletions

@@ -7,8 +7,8 @@
 from typing import Tuple

 import torch
-from torch.ao.quantization.observer import MappingType, PerBlock
-from torch.ao.quantization.pt2e._affine_quantization import (
+from torchao.quantization.pt2e import MappingType, PerBlock
+from torchao.quantization.pt2e._affine_quantization import (
     _get_reduction_params,
     AffineQuantizedMinMaxObserver,
     choose_qparams_affine_with_min_max,

backends/qualcomm/quantizer/observers/per_channel_param_observer.py

Lines changed: 1 addition & 1 deletion

@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.

 import torch
-from torch.ao.quantization.observer import UniformQuantizationObserverBase
+from torchao.quantization.pt2e import UniformQuantizationObserverBase


 # TODO move to torch/ao/quantization/observer.py.

backends/qualcomm/quantizer/qconfig.py

Lines changed: 6 additions & 5 deletions

@@ -7,18 +7,19 @@
     PerBlockParamObserver,
 )
 from torch import Tensor
-from torch.ao.quantization.fake_quantize import (
+from torch.fx import Node
+from torchao.quantization.pt2e import (
     FakeQuantize,
     FusedMovingAvgObsFakeQuantize,
-)
-from torch.ao.quantization.observer import (
     MinMaxObserver,
     MovingAverageMinMaxObserver,
     MovingAveragePerChannelMinMaxObserver,
     PerChannelMinMaxObserver,
 )
-from torch.ao.quantization.quantizer import DerivedQuantizationSpec, QuantizationSpec
-from torch.fx import Node
+from torchao.quantization.pt2e.quantizer import (
+    DerivedQuantizationSpec,
+    QuantizationSpec,
+)


 @dataclass(eq=True)
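Note: with the imports consolidated under torchao, a quantization spec is built the same way as before, just from the new modules. A rough sketch (the 8-bit values, eps, and observer choice are illustrative assumptions, not one of the qconfigs defined in this file):

import torch
from torchao.quantization.pt2e import MovingAverageMinMaxObserver
from torchao.quantization.pt2e.quantizer import QuantizationSpec

# Illustrative 8-bit per-tensor activation spec.
act_quantization_spec = QuantizationSpec(
    dtype=torch.uint8,
    quant_min=0,
    quant_max=255,
    qscheme=torch.per_tensor_affine,
    observer_or_fake_quant_ctr=MovingAverageMinMaxObserver.with_args(eps=2**-12),
)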

backends/qualcomm/quantizer/quantizer.py

Lines changed: 3 additions & 4 deletions

@@ -12,8 +12,9 @@
 from executorch.backends.qualcomm._passes.qnn_pass_manager import QnnPassManager

 from torch._ops import OpOverload
-from torch.ao.quantization.quantizer import Quantizer
 from torch.fx import GraphModule
+from torchao.quantization.pt2e import UniformQuantizationObserverBase
+from torchao.quantization.pt2e.quantizer import Quantizer

 from .annotators import OP_ANNOTATOR

@@ -130,9 +131,7 @@ class ModuleQConfig:
     is_qat: bool = False
     is_conv_per_channel: bool = False
     is_linear_per_channel: bool = False
-    act_observer: Optional[
-        torch.ao.quantization.observer.UniformQuantizationObserverBase
-    ] = None
+    act_observer: Optional[UniformQuantizationObserverBase] = None

     def __post_init__(self):
         if (self.quant_dtype, self.is_qat) not in QUANT_CONFIG_DICT:

backends/qualcomm/tests/utils.py

Lines changed: 4 additions & 3 deletions

@@ -14,6 +14,7 @@

 import numpy as np
 import torch
+import torchao
 from executorch import exir
 from executorch.backends.qualcomm._passes.utils import dq_ops
 from executorch.backends.qualcomm.qnn_preprocess import QnnBackend

@@ -537,8 +538,8 @@ def get_qdq_module(
             torch.ops.quantized_decomposed.dequantize_per_tensor.default,
             torch.ops.quantized_decomposed.quantize_per_channel.default,
             torch.ops.quantized_decomposed.dequantize_per_channel.default,
-            torch.ops.pt2e_quant.quantize_affine.default,
-            torch.ops.pt2e_quant.dequantize_affine.default,
+            torch.ops.torchao.quantize_affine.default,
+            torch.ops.torchao.dequantize_affine.default,
         }
         if not bypass_check:
             self.assertTrue(nodes.intersection(q_and_dq))

@@ -569,7 +570,7 @@ def get_prepared_qat_module(
         quantizer.set_submodule_qconfig_list(submodule_qconfig_list)

         prepared = prepare_qat_pt2e(m, quantizer)
-        return torch.ao.quantization.move_exported_model_to_train(prepared)
+        return torchao.quantization.pt2e.move_exported_model_to_train(prepared)

     def get_converted_sgd_trained_module(
         self,
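Note: the QAT test helper now uses torchao's copy of move_exported_model_to_train. A condensed sketch of that flow under the new imports (quantizer construction and export details are assumed, and prepare_qat_pt2e is assumed to live in torchao.quantization.pt2e.quantize_pt2e, mirroring the prepare_pt2e import seen elsewhere in this commit):

import torch
import torchao
from torchao.quantization.pt2e.quantize_pt2e import prepare_qat_pt2e


def prepare_for_qat(model: torch.nn.Module, example_inputs: tuple, quantizer):
    # Export to an FX graph, insert fake-quant/observers, then flip to train mode.
    exported = torch.export.export(model, example_inputs).module()
    prepared = prepare_qat_pt2e(exported, quantizer)
    return torchao.quantization.pt2e.move_exported_model_to_train(prepared)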

examples/qualcomm/oss_scripts/llama/llama.py

Lines changed: 1 addition & 1 deletion

@@ -81,7 +81,7 @@
 from pytorch_tokenizers import get_tokenizer, TiktokenTokenizer
 from pytorch_tokenizers.llama2c import Llama2cTokenizer as SentencePieceTokenizer

-from torch.ao.quantization.observer import MinMaxObserver
+from torchao.quantization.pt2e import MinMaxObserver
 from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

 sys.setrecursionlimit(4096)
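Note: MinMaxObserver now comes from torchao.quantization.pt2e, alongside the prepare/convert entry points this script already imports from torchao. A bare-bones PTQ sketch tying those imports together (model, quantizer, and calibration data are placeholders, not the llama flow from this file):

import torch
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e


def quantize_model(model: torch.nn.Module, example_inputs: tuple, quantizer, calibration_batches):
    # Export, insert observers, run calibration data through, then fold q/dq ops in.
    exported = torch.export.export(model, example_inputs).module()
    prepared = prepare_pt2e(exported, quantizer)
    for batch in calibration_batches:
        prepared(*batch)
    return convert_pt2e(prepared)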

examples/qualcomm/oss_scripts/moshi/mimi.py

Lines changed: 1 addition & 1 deletion

@@ -37,7 +37,7 @@
 from huggingface_hub import hf_hub_download
 from moshi.models import loaders

-from torch.ao.quantization.observer import MinMaxObserver
+from torchao.quantization.pt2e import MinMaxObserver


 def seed_all(seed):
