|
16 | 16 | from executorch.backends.arm.operators.operator_validation_utils import (
|
17 | 17 | validate_num_inputs,
|
18 | 18 | )
|
19 |
| -from executorch.backends.arm.tosa_mapping import TosaArg |
20 |
| -from executorch.backends.arm.tosa_quant_utils import create_const_ops_for_rescale |
| 19 | +from executorch.backends.arm.tosa_mapping import map_dtype, TosaArg |
| 20 | +from executorch.backends.arm.tosa_quant_utils import build_rescale |
21 | 21 |
|
22 | 22 | from executorch.backends.arm.tosa_specification import TosaSpecification
|
23 | 23 | from torch.fx import Node
|
@@ -98,53 +98,29 @@ def define_node(
|
98 | 98 |
|
99 | 99 | validate_num_inputs(self.target, inputs, 5)
|
100 | 100 |
|
101 |
| - input_dtype = node.all_input_nodes[0].meta["val"].dtype |
| 101 | + input_dtype = inputs[0].dtype |
102 | 102 | output_dtype = cast(torch.dtype, node.args[1])
|
103 | 103 | scale = cast(float, node.args[2])
|
104 | 104 | input_zp = cast(int, node.args[3])
|
105 | 105 | output_zp = cast(int, node.args[4])
|
106 | 106 |
|
107 |
| - if input_dtype != torch.int8 and input_zp != 0: |
| 107 | + if input_dtype != map_dtype(torch.int8, self.tosa_spec) and input_zp != 0: |
108 | 108 | raise ValueError(
|
109 | 109 | f"If input dtype is not int8, input_zp must be 0. Got input_dtype{input_dtype=}, {input_zp=}"
|
110 | 110 | )
|
111 | 111 | if output_dtype != torch.int8 and output_zp != 0:
|
112 | 112 | raise ValueError(
|
113 |
| - f"If output dtype is not int8, output_zp must be 0. Got {output_dtype=}, {output_zp=}" |
| 113 | + f"If output dtype is not int8, output_zp must be 0. Got {ts.DTypeNames[output_dtype]}, {output_zp=}" |
114 | 114 | )
|
115 | 115 |
|
116 |
| - # scale32 gives higher accuracy but for a higher HW cost. |
117 |
| - # For now, always go for scale32. |
118 |
| - scale_32 = True |
119 |
| - scale_width = 32 if scale_32 else 16 |
120 |
| - multipliers, shifts = tosa_quant_utils.compute_multiplier_and_shift( |
121 |
| - [scale], scale_width |
122 |
| - ) |
123 |
| - |
124 |
| - rescale_inputs = create_const_ops_for_rescale( |
| 116 | + build_rescale( |
125 | 117 | tosa_graph,
|
126 |
| - input_dtype, |
127 |
| - inputs[0].name, |
128 |
| - multipliers, |
129 |
| - shifts, |
130 |
| - input_zp, |
131 |
| - output_zp, |
132 |
| - ts, |
133 |
| - ) |
134 |
| - |
135 |
| - attr_rescale = ts.TosaSerializerAttribute() |
136 |
| - |
137 |
| - attr_rescale.RescaleAttribute( |
138 |
| - scale32=scale_32, |
| 118 | + scale=[scale], |
| 119 | + input_node=inputs[0], |
| 120 | + output_name=output.name, |
| 121 | + output_type=output.dtype, |
| 122 | + input_zp=input_zp, |
| 123 | + output_zp=output_zp, |
139 | 124 | rounding_mode=RoundingMode.SINGLE_ROUND,
|
140 | 125 | per_channel=False,
|
141 |
| - input_unsigned=False, |
142 |
| - output_unsigned=False, |
143 |
| - ) |
144 |
| - |
145 |
| - tosa_graph.addOperator( |
146 |
| - ts.TosaOp.Op().RESCALE, |
147 |
| - [inputs[0].name, *rescale_inputs], |
148 |
| - [output.name], |
149 |
| - attr_rescale, |
150 | 126 | )
|
0 commit comments