Skip to content

Commit 0ac9c5f

Browse files
mcr229facebook-github-bot
authored andcommitted
quant_params should come from implicit nodes (#301)
Summary: Pull Request resolved: #301 QuantParams are generated for nodes to imply that these operations are quantized operations. However, if the QuantNodes are explicit, i.e. being used to actually convert data type, then we should not be generating quant_params. This is to help with mixed-datatype modles Reviewed By: salilsdesai Differential Revision: D48805092 fbshipit-source-id: 7a7423e93cdb2629842701ad8271a18464e60c33
1 parent 5c5ce43 commit 0ac9c5f

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

backends/xnnpack/operators/quant_params.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from typing import cast, Optional, Union
1010

1111
import torch
12+
from executorch.backends.xnnpack.passes.tag_implicit_q_dq_pass import TagImplicitQDqPass
1213
from executorch.backends.xnnpack.utils.quant_utils import is_dequant, is_quant
1314
from executorch.backends.xnnpack.utils.utils import check_or_raise
1415
from executorch.exir.dialects._ops import ops as exir_ops
@@ -179,7 +180,9 @@ def from_weights(cls, tensor_node: torch.fx.Node) -> Optional[QuantParams]:
179180
@classmethod
180181
def from_inputs(cls, tensor_node: torch.fx.Node) -> Optional[QuantParams]:
181182
# tensor_node is quantized if it is produced by a dequant node
182-
if is_dequant(tensor_node):
183+
if is_dequant(tensor_node) and TagImplicitQDqPass.is_tagged_as_implicit_q_dq(
184+
tensor_node
185+
):
183186
return cls.from_q_dq_node(tensor_node)
184187

185188
return None
@@ -190,7 +193,7 @@ def from_outputs(cls, tensor_node: torch.fx.Node) -> Optional[QuantParams]:
190193
if len(tensor_node.users) == 1:
191194
q = list(tensor_node.users.keys())[0]
192195
# Check if user is a q node
193-
if is_quant(q):
196+
if is_quant(q) and TagImplicitQDqPass.is_tagged_as_implicit_q_dq(q):
194197
return cls.from_q_dq_node(q)
195198

196199
return None

0 commit comments

Comments
 (0)