Skip to content

Commit 6230f8f

Browse files
mcr229 and facebook-github-bot
authored and committed
quant params from static inputs (#573)
Summary: Pull Request resolved: #573. Since we allow tensor constants to be quantized inputs, the from_inputs API must now check whether the given input is static. If it is static, we take the q node in the get_attr --> q --> dq chain to create the QuantParams object. If it is not static, we take the dq node itself to create the QuantParams object. Previously, static quantized inputs were only accepted for weights and biases. Reviewed By: digantdesai. Differential Revision: D49850149. fbshipit-source-id: 007d594977f144c0fa58b6db01ca63e52e40312d
1 parent 98af719 commit 6230f8f

File tree

8 files changed

+25
-13
lines changed

8 files changed

+25
-13
lines changed

backends/xnnpack/operators/node_visitor.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -423,7 +423,7 @@ def define_nodes_tensor_inputs_outputs(
423423
inp,
424424
xnn_graph,
425425
vals_to_ids,
426-
quant_params=QuantParams.from_inputs(inp),
426+
quant_params=QuantParams.from_inputs(inp, self._exported_program),
427427
convert_to_nhwc=convert_to_nhwc,
428428
)
429429
else:
@@ -434,7 +434,9 @@ def define_nodes_tensor_inputs_outputs(
434434
)
435435
# Define Input Node
436436
input_node = get_input_node(node, input_type_map.node_input)
437-
input_quant_params = QuantParams.from_inputs(input_node)
437+
input_quant_params = QuantParams.from_inputs(
438+
input_node, self._exported_program
439+
)
438440
self.define_tensor(
439441
input_node,
440442
xnn_graph,

backends/xnnpack/operators/op_add.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def define_node(
4343
input1,
4444
xnn_graph,
4545
vals_to_ids,
46-
quant_params=QuantParams.from_inputs(input1),
46+
quant_params=QuantParams.from_inputs(input1, self._exported_program),
4747
)
4848
input1_id = vals_to_ids[input1]
4949

@@ -53,7 +53,7 @@ def define_node(
5353
input2,
5454
xnn_graph,
5555
vals_to_ids,
56-
quant_params=QuantParams.from_inputs(input2),
56+
quant_params=QuantParams.from_inputs(input2, self._exported_program),
5757
)
5858
input2_id = vals_to_ids[input2]
5959

backends/xnnpack/operators/op_cat.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,9 @@ def define_node(
4646
tensor_input,
4747
xnn_graph,
4848
vals_to_ids,
49-
quant_params=QuantParams.from_inputs(tensor_input),
49+
quant_params=QuantParams.from_inputs(
50+
tensor_input, self._exported_program
51+
),
5052
)
5153

5254
self.define_tensor(

backends/xnnpack/operators/op_conv2d.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def define_node(
4646
kwargs = {}
4747
# input
4848
input_node = get_input_node(node, 0)
49-
input_quant_params = QuantParams.from_inputs(input_node)
49+
input_quant_params = QuantParams.from_inputs(input_node, self._exported_program)
5050
self.define_tensor(
5151
input_node,
5252
xnn_graph,

backends/xnnpack/operators/op_multiply.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def define_node(
4343
input1,
4444
xnn_graph,
4545
vals_to_ids,
46-
quant_params=QuantParams.from_inputs(input1),
46+
quant_params=QuantParams.from_inputs(input1, self._exported_program),
4747
)
4848
input1_id = vals_to_ids[input1]
4949

@@ -53,7 +53,7 @@ def define_node(
5353
input2,
5454
xnn_graph,
5555
vals_to_ids,
56-
quant_params=QuantParams.from_inputs(input2),
56+
quant_params=QuantParams.from_inputs(input2, self._exported_program),
5757
)
5858
input2_id = vals_to_ids[input2]
5959

backends/xnnpack/operators/op_sub.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def define_node(
4343
input1,
4444
xnn_graph,
4545
vals_to_ids,
46-
quant_params=QuantParams.from_inputs(input1),
46+
quant_params=QuantParams.from_inputs(input1, self._exported_program),
4747
)
4848
input1_id = vals_to_ids[input1]
4949

@@ -53,7 +53,7 @@ def define_node(
5353
input2,
5454
xnn_graph,
5555
vals_to_ids,
56-
quant_params=QuantParams.from_inputs(input2),
56+
quant_params=QuantParams.from_inputs(input2, self._exported_program),
5757
)
5858
input2_id = vals_to_ids[input2]
5959

backends/xnnpack/operators/op_to_copy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def define_node(
5151
)
5252

5353
input_node = get_input_node(node, 0)
54-
input_quant_params = QuantParams.from_inputs(input_node)
54+
input_quant_params = QuantParams.from_inputs(input_node, self._exported_program)
5555
output_quant_params = QuantParams.from_outputs(node)
5656

5757
permute_order = PERM_NCHW_TO_NHWC if to_channels_last else PERM_NHWC_TO_NCHW

backends/xnnpack/operators/quant_params.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,9 @@
1111
import torch
1212
from executorch.backends.xnnpack.passes.tag_implicit_q_dq_pass import TagImplicitQDqPass
1313
from executorch.backends.xnnpack.utils.quant_utils import is_dequant, is_quant
14-
from executorch.backends.xnnpack.utils.utils import check_or_raise
14+
from executorch.backends.xnnpack.utils.utils import check_or_raise, is_param_node
1515
from executorch.exir.dialects._ops import ops as exir_ops
16+
from torch.export import ExportedProgram
1617

1718

1819
class QuantParams:
@@ -178,11 +179,18 @@ def from_weights(cls, tensor_node: torch.fx.Node) -> Optional[QuantParams]:
178179
return cls.from_q_dq_node(q)
179180

180181
@classmethod
181-
def from_inputs(cls, tensor_node: torch.fx.Node) -> Optional[QuantParams]:
182+
def from_inputs(
183+
cls, tensor_node: torch.fx.Node, ep: ExportedProgram
184+
) -> Optional[QuantParams]:
182185
# tensor_node is quantized if it is produced by a dequant node
183186
if is_dequant(tensor_node) and TagImplicitQDqPass.is_tagged_as_implicit_q_dq(
184187
tensor_node
185188
):
189+
dq_input = cast(torch.fx.Node, tensor_node.args[0])
190+
if is_quant(dq_input):
191+
q_input = cast(torch.fx.Node, dq_input.args[0])
192+
if is_param_node(ep, q_input):
193+
return cls.from_q_dq_node(dq_input)
186194
return cls.from_q_dq_node(tensor_node)
187195

188196
return None

0 commit comments

Comments
 (0)