Skip to content

Commit c550354

Browse files
mcr229facebook-github-bot
authored andcommitted
quant params from static inputs
Differential Revision: D49850149 fbshipit-source-id: fb8910116d1b2ac5eecdccdd2f3a64eb275fd122
1 parent 79f3db2 commit c550354

File tree

8 files changed

+25
-13
lines changed

8 files changed

+25
-13
lines changed

backends/xnnpack/operators/node_visitor.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -423,7 +423,7 @@ def define_nodes_tensor_inputs_outputs(
423423
inp,
424424
xnn_graph,
425425
vals_to_ids,
426-
quant_params=QuantParams.from_inputs(inp),
426+
quant_params=QuantParams.from_inputs(inp, self._exported_program),
427427
convert_to_nhwc=convert_to_nhwc,
428428
)
429429
else:
@@ -434,7 +434,9 @@ def define_nodes_tensor_inputs_outputs(
434434
)
435435
# Define Input Node
436436
input_node = get_input_node(node, input_type_map.node_input)
437-
input_quant_params = QuantParams.from_inputs(input_node)
437+
input_quant_params = QuantParams.from_inputs(
438+
input_node, self._exported_program
439+
)
438440
self.define_tensor(
439441
input_node,
440442
xnn_graph,

backends/xnnpack/operators/op_add.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def define_node(
4343
input1,
4444
xnn_graph,
4545
vals_to_ids,
46-
quant_params=QuantParams.from_inputs(input1),
46+
quant_params=QuantParams.from_inputs(input1, self._exported_program),
4747
)
4848
input1_id = vals_to_ids[input1]
4949

@@ -53,7 +53,7 @@ def define_node(
5353
input2,
5454
xnn_graph,
5555
vals_to_ids,
56-
quant_params=QuantParams.from_inputs(input2),
56+
quant_params=QuantParams.from_inputs(input2, self._exported_program),
5757
)
5858
input2_id = vals_to_ids[input2]
5959

backends/xnnpack/operators/op_cat.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,9 @@ def define_node(
4646
tensor_input,
4747
xnn_graph,
4848
vals_to_ids,
49-
quant_params=QuantParams.from_inputs(tensor_input),
49+
quant_params=QuantParams.from_inputs(
50+
tensor_input, self._exported_program
51+
),
5052
)
5153

5254
self.define_tensor(

backends/xnnpack/operators/op_conv2d.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def define_node(
4646
kwargs = {}
4747
# input
4848
input_node = get_input_node(node, 0)
49-
input_quant_params = QuantParams.from_inputs(input_node)
49+
input_quant_params = QuantParams.from_inputs(input_node, self._exported_program)
5050
self.define_tensor(
5151
input_node,
5252
xnn_graph,

backends/xnnpack/operators/op_multiply.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def define_node(
4343
input1,
4444
xnn_graph,
4545
vals_to_ids,
46-
quant_params=QuantParams.from_inputs(input1),
46+
quant_params=QuantParams.from_inputs(input1, self._exported_program),
4747
)
4848
input1_id = vals_to_ids[input1]
4949

@@ -53,7 +53,7 @@ def define_node(
5353
input2,
5454
xnn_graph,
5555
vals_to_ids,
56-
quant_params=QuantParams.from_inputs(input2),
56+
quant_params=QuantParams.from_inputs(input2, self._exported_program),
5757
)
5858
input2_id = vals_to_ids[input2]
5959

backends/xnnpack/operators/op_sub.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def define_node(
4343
input1,
4444
xnn_graph,
4545
vals_to_ids,
46-
quant_params=QuantParams.from_inputs(input1),
46+
quant_params=QuantParams.from_inputs(input1, self._exported_program),
4747
)
4848
input1_id = vals_to_ids[input1]
4949

@@ -53,7 +53,7 @@ def define_node(
5353
input2,
5454
xnn_graph,
5555
vals_to_ids,
56-
quant_params=QuantParams.from_inputs(input2),
56+
quant_params=QuantParams.from_inputs(input2, self._exported_program),
5757
)
5858
input2_id = vals_to_ids[input2]
5959

backends/xnnpack/operators/op_to_copy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def define_node(
5151
)
5252

5353
input_node = get_input_node(node, 0)
54-
input_quant_params = QuantParams.from_inputs(input_node)
54+
input_quant_params = QuantParams.from_inputs(input_node, self._exported_program)
5555
output_quant_params = QuantParams.from_outputs(node)
5656

5757
permute_order = PERM_NCHW_TO_NHWC if to_channels_last else PERM_NHWC_TO_NCHW

backends/xnnpack/operators/quant_params.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,9 @@
1111
import torch
1212
from executorch.backends.xnnpack.passes.tag_implicit_q_dq_pass import TagImplicitQDqPass
1313
from executorch.backends.xnnpack.utils.quant_utils import is_dequant, is_quant
14-
from executorch.backends.xnnpack.utils.utils import check_or_raise
14+
from executorch.backends.xnnpack.utils.utils import check_or_raise, is_param_node
1515
from executorch.exir.dialects._ops import ops as exir_ops
16+
from torch.export import ExportedProgram
1617

1718

1819
class QuantParams:
@@ -178,11 +179,18 @@ def from_weights(cls, tensor_node: torch.fx.Node) -> Optional[QuantParams]:
178179
return cls.from_q_dq_node(q)
179180

180181
@classmethod
181-
def from_inputs(cls, tensor_node: torch.fx.Node) -> Optional[QuantParams]:
182+
def from_inputs(
183+
cls, tensor_node: torch.fx.Node, ep: ExportedProgram
184+
) -> Optional[QuantParams]:
182185
# tensor_node is quantized if it is produced by a dequant node
183186
if is_dequant(tensor_node) and TagImplicitQDqPass.is_tagged_as_implicit_q_dq(
184187
tensor_node
185188
):
189+
dq_input = cast(torch.fx.Node, tensor_node.args[0])
190+
if is_quant(dq_input):
191+
q_input = cast(torch.fx.Node, dq_input.args[0])
192+
if is_param_node(ep, q_input):
193+
return cls.from_q_dq_node(dq_input)
186194
return cls.from_q_dq_node(tensor_node)
187195

188196
return None

0 commit comments

Comments
 (0)