Skip to content

Commit e078e93

Browse files
authored
Skip annotate boolean input (#2957) (#3051)
* Skip annotate boolean input (#2957) Summary: Pull Request resolved: #2957 ghstack-source-id: 222200589 exported-using-ghexport It only makes sense to quantize fp tensor, but not boolean. Add a check to make sure only fp tensor are annotated in quantizer Reviewed By: jerryzh168 Differential Revision: D55946526 fbshipit-source-id: d94bfee38ab2d29fc9672ab631b4d5d0c5239d25 * fix lint
1 parent 925f674 commit e078e93

File tree

1 file changed

+16
-2
lines changed

1 file changed

+16
-2
lines changed

backends/qualcomm/quantizer/utils.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import torch
1010

1111
from torch._ops import OpOverload
12+
from torch._subclasses import FakeTensor
1213

1314
from torch.ao.quantization.quantizer import (
1415
QuantizationAnnotation,
@@ -42,6 +43,19 @@ def decorator(annotator: Callable):
4243
return decorator
4344

4445

46+
def _is_input_float_tensor(node: Node):
47+
"""Check if the input is not a float tensor, so that we can skip quantization for the node
48+
since observers only works with float Tensors
49+
"""
50+
if (
51+
not isinstance(node, Node)
52+
or "val" not in node.meta
53+
or not isinstance(node.meta["val"], FakeTensor)
54+
):
55+
return False
56+
return node.meta["val"].dtype == torch.float32
57+
58+
4559
def _is_annotated(nodes: List[Node]):
4660
"""
4761
Given a list of nodes (that represents an operator pattern),
@@ -123,11 +137,11 @@ def annotate_binary(node: Node, quantization_config: QuantizationConfig) -> None
123137

124138
input_qspec_map = {}
125139
input_act0 = node.args[0]
126-
if isinstance(input_act0, Node):
140+
if _is_input_float_tensor(input_act0):
127141
input_qspec_map[input_act0] = input_act_qspec
128142

129143
input_act1 = node.args[1]
130-
if isinstance(input_act1, Node):
144+
if _is_input_float_tensor(input_act1):
131145
input_qspec_map[input_act1] = input_act_qspec
132146

133147
node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(

0 commit comments

Comments
 (0)