@@ -9,6 +9,7 @@
 import torch

 from torch._ops import OpOverload
+from torch._subclasses import FakeTensor

 from torch.ao.quantization.quantizer import (
     QuantizationAnnotation,
@@ -42,6 +43,19 @@ def decorator(annotator: Callable):
     return decorator


+def _is_input_float_tensor(node: Node):
+    """Check if the input is a float tensor, so that we can skip quantization for the node
+    if it is not, since observers only work with float Tensors
+    """
+    if (
+        not isinstance(node, Node)
+        or "val" not in node.meta
+        or not isinstance(node.meta["val"], FakeTensor)
+    ):
+        return False
+    return node.meta["val"].dtype == torch.float32
+
+
 def _is_annotated(nodes: List[Node]):
     """
     Given a list of nodes (that represents an operator pattern),
@@ -123,11 +137,11 @@ def annotate_binary(node: Node, quantization_config: QuantizationConfig) -> None

     input_qspec_map = {}
     input_act0 = node.args[0]
-    if isinstance(input_act0, Node):
+    if _is_input_float_tensor(input_act0):
         input_qspec_map[input_act0] = input_act_qspec

     input_act1 = node.args[1]
-    if isinstance(input_act1, Node):
+    if _is_input_float_tensor(input_act1):
         input_qspec_map[input_act1] = input_act_qspec

     node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
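
The change tightens annotate_binary: previously any Node argument was given an input qspec, but observers can only be attached to float tensors, so integer or non-Node inputs should be skipped. Below is a minimal, self-contained sketch (not part of this PR) of how the new helper behaves; the toy Add module and the hand-populated meta["val"] entries are assumptions standing in for what tracing normally records, and the helper body is copied from the diff above.

import torch
import torch.fx
from torch._subclasses import FakeTensor, FakeTensorMode
from torch.fx import Node


def _is_input_float_tensor(node: Node):
    """Helper copied from the diff above."""
    if (
        not isinstance(node, Node)
        or "val" not in node.meta
        or not isinstance(node.meta["val"], FakeTensor)
    ):
        return False
    return node.meta["val"].dtype == torch.float32


class Add(torch.nn.Module):
    def forward(self, x, y):
        return x + y


gm = torch.fx.symbolic_trace(Add())
fake_mode = FakeTensorMode()

# Attach FakeTensor metadata the way tracing would: one float input,
# one integer input.
x_node, y_node = [n for n in gm.graph.nodes if n.op == "placeholder"]
x_node.meta["val"] = fake_mode.from_tensor(torch.empty(2, dtype=torch.float32))
y_node.meta["val"] = fake_mode.from_tensor(torch.empty(2, dtype=torch.int64))

print(_is_input_float_tensor(x_node))  # True  -> gets an input qspec entry
print(_is_input_float_tensor(y_node))  # False -> skipped, observers need float
print(_is_input_float_tensor(1.0))     # False -> non-Node args are skipped too

Relying on node.meta["val"] means the check only works on graphs that carry FakeTensor metadata from tracing; nodes without it are conservatively treated as non-float and left unannotated.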