Skip to content

Commit 4c0854b

Browse files
mcr229facebook-github-bot
authored andcommitted
quant_params should come from implicit nodes
Differential Revision: D48805092 fbshipit-source-id: 41c8ac41060281c77f1f04c40b4784022b880423
1 parent 5e68f05 commit 4c0854b

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

backends/xnnpack/operators/quant_params.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from typing import cast, Optional, Union
1010

1111
import torch
12+
from executorch.backends.xnnpack.passes.tag_implicit_q_dq_pass import TagImplicitQDqPass
1213
from executorch.backends.xnnpack.utils.quant_utils import is_dequant, is_quant
1314
from executorch.backends.xnnpack.utils.utils import check_or_raise
1415
from executorch.exir.dialects._ops import ops as exir_ops
@@ -179,7 +180,9 @@ def from_weights(cls, tensor_node: torch.fx.Node) -> Optional[QuantParams]:
179180
@classmethod
180181
def from_inputs(cls, tensor_node: torch.fx.Node) -> Optional[QuantParams]:
181182
# tensor_node is quantized if it is produced by a dequant node
182-
if is_dequant(tensor_node):
183+
if is_dequant(tensor_node) and TagImplicitQDqPass.is_tagged_as_implicit_q_dq(
184+
tensor_node
185+
):
183186
return cls.from_q_dq_node(tensor_node)
184187

185188
return None
@@ -190,7 +193,7 @@ def from_outputs(cls, tensor_node: torch.fx.Node) -> Optional[QuantParams]:
190193
if len(tensor_node.users) == 1:
191194
q = list(tensor_node.users.keys())[0]
192195
# Check if user is a q node
193-
if is_quant(q):
196+
if is_quant(q) and TagImplicitQDqPass.is_tagged_as_implicit_q_dq(q):
194197
return cls.from_q_dq_node(q)
195198

196199
return None

0 commit comments

Comments
 (0)