Skip to content

Commit 4909db1

Browse files
Arm backend: Update rescale to handle more dtypes
Update op_rescale to handle other dtype conversion than int8 <-> int32 for TOSA 1.0. Signed-off-by: Oscar Andersson <[email protected]> Change-Id: Icc19fb8bb391ec063df2f4cb7dddaf8db672332f
1 parent fbb3ad1 commit 4909db1

File tree

2 files changed

+29
-41
lines changed

2 files changed

+29
-41
lines changed

backends/arm/operators/op_rescale.py

Lines changed: 12 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616
from executorch.backends.arm.operators.operator_validation_utils import (
1717
validate_num_inputs,
1818
)
19-
from executorch.backends.arm.tosa_mapping import TosaArg
20-
from executorch.backends.arm.tosa_quant_utils import create_const_ops_for_rescale
19+
from executorch.backends.arm.tosa_mapping import map_dtype, TosaArg
20+
from executorch.backends.arm.tosa_quant_utils import build_rescale
2121

2222
from executorch.backends.arm.tosa_specification import TosaSpecification
2323
from torch.fx import Node
@@ -98,53 +98,29 @@ def define_node(
9898

9999
validate_num_inputs(self.target, inputs, 5)
100100

101-
input_dtype = node.all_input_nodes[0].meta["val"].dtype
101+
input_dtype = inputs[0].dtype
102102
output_dtype = cast(torch.dtype, node.args[1])
103103
scale = cast(float, node.args[2])
104104
input_zp = cast(int, node.args[3])
105105
output_zp = cast(int, node.args[4])
106106

107-
if input_dtype != torch.int8 and input_zp != 0:
107+
if input_dtype != map_dtype(torch.int8, self.tosa_spec) and input_zp != 0:
108108
raise ValueError(
109109
f"If input dtype is not int8, input_zp must be 0. Got input_dtype{input_dtype=}, {input_zp=}"
110110
)
111111
if output_dtype != torch.int8 and output_zp != 0:
112112
raise ValueError(
113-
f"If output dtype is not int8, output_zp must be 0. Got {output_dtype=}, {output_zp=}"
113+
f"If output dtype is not int8, output_zp must be 0. Got {ts.DTypeNames[output_dtype]}, {output_zp=}"
114114
)
115115

116-
# scale32 gives higher accuracy but for a higher HW cost.
117-
# For now, always go for scale32.
118-
scale_32 = True
119-
scale_width = 32 if scale_32 else 16
120-
multipliers, shifts = tosa_quant_utils.compute_multiplier_and_shift(
121-
[scale], scale_width
122-
)
123-
124-
rescale_inputs = create_const_ops_for_rescale(
116+
build_rescale(
125117
tosa_graph,
126-
input_dtype,
127-
inputs[0].name,
128-
multipliers,
129-
shifts,
130-
input_zp,
131-
output_zp,
132-
ts,
133-
)
134-
135-
attr_rescale = ts.TosaSerializerAttribute()
136-
137-
attr_rescale.RescaleAttribute(
138-
scale32=scale_32,
118+
scale=[scale],
119+
input_node=inputs[0],
120+
output_name=output.name,
121+
output_type=output.dtype,
122+
input_zp=input_zp,
123+
output_zp=output_zp,
139124
rounding_mode=RoundingMode.SINGLE_ROUND,
140125
per_channel=False,
141-
input_unsigned=False,
142-
output_unsigned=False,
143-
)
144-
145-
tosa_graph.addOperator(
146-
ts.TosaOp.Op().RESCALE,
147-
[inputs[0].name, *rescale_inputs],
148-
[output.name],
149-
attr_rescale,
150126
)

backends/arm/tosa_quant_utils.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -236,13 +236,21 @@ def build_rescale_v0_80(
236236
# For TOSA spec v1.0 RESCALE operator requires multipler, shifts, input_zp and output_zp to be
237237
# const inputs. Create constant operators from the data already initialized.
238238
def create_const_ops_for_rescale(
239-
tosa_fb, input_dtype, input_name, multipliers, shifts, input_zp, output_zp, ts
239+
tosa_fb,
240+
scale_32,
241+
input_dtype,
242+
input_name,
243+
multipliers,
244+
shifts,
245+
input_zp,
246+
output_zp,
247+
output_dtype,
248+
ts,
240249
):
241-
output_dtype = ts.DType.INT32 if input_dtype == ts.DType.INT8 else ts.DType.INT8
242250

243251
multipliers = tosa_fb.addConst(
244252
(len(multipliers),),
245-
ts.DType.INT32,
253+
ts.DType.INT32 if scale_32 else ts.DType.INT16,
246254
multipliers,
247255
name=input_name + "_multipliers",
248256
)
@@ -275,20 +283,24 @@ def build_rescale(
275283

276284
input_name = input_node.name
277285

278-
multipliers, shifts = compute_multiplier_and_shift(scale, 32)
286+
scaleWidth = 32
287+
is_scale32 = True
288+
multipliers, shifts = compute_multiplier_and_shift(scale, scaleWidth)
279289
rescale_inputs = create_const_ops_for_rescale(
280290
tosa_fb,
291+
is_scale32,
281292
input_node.dtype,
282293
input_name,
283294
multipliers,
284295
shifts,
285296
input_zp,
286297
output_zp,
298+
output_type,
287299
ts,
288300
)
289301
attr_rescale = ts.TosaSerializerAttribute()
290302
attr_rescale.RescaleAttribute(
291-
scale32=True,
303+
scale32=is_scale32,
292304
rounding_mode=rounding_mode,
293305
per_channel=per_channel,
294306
input_unsigned=False,

0 commit comments

Comments
 (0)