Skip to content

Commit 9b72314

Browse files
committed
Remove hard coded argument types
Move away from implicitly assuming arguments are torch.int8 and figure out the type from the quantization nodes instead. This is done to prepare for breaking up the TOSA conversion and serialization into separate parts. Signed-off-by: Per Åstrand <[email protected]> Change-Id: Id88ef8f264e6af8e90a92a00fca13cdbcc857bab
1 parent 093e735 commit 9b72314

File tree

3 files changed

+40
-5
lines changed

3 files changed

+40
-5
lines changed

backends/arm/arm_backend.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717
from executorch.backends.arm.arm_vela import vela_compile
1818
from executorch.backends.arm.operators.node_visitor import get_node_visitors
1919
from executorch.backends.arm.operators.op_placeholder import process_placeholder
20-
from executorch.backends.arm.tosa_mapping import TosaArg
21-
from executorch.backends.arm.tosa_quant_utils import is_quant_node
20+
from executorch.backends.arm.tosa_mapping import map_dtype, TosaArg
21+
from executorch.backends.arm.tosa_quant_utils import get_quant_node_dtype, is_quant_node
2222
from executorch.backends.arm.tosa_utils import (
2323
dbg_fail,
2424
dbg_tosa_dump,
@@ -280,7 +280,11 @@ def preprocess( # noqa: C901
280280
if is_permute_node_before_addmm(node)
281281
else output.shape
282282
),
283-
ts.DType.INT8 if is_quant_node(node) else output.dtype,
283+
(
284+
map_dtype(get_quant_node_dtype(node))
285+
if is_quant_node(node)
286+
else output.dtype
287+
),
284288
)
285289

286290
# Visiting each Node

backends/arm/operators/op_placeholder.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import torch
99
from executorch.backends.arm.tosa_mapping import TosaArg
1010
from executorch.backends.arm.tosa_quant_utils import (
11+
get_quant_arg_dtype,
1112
get_quant_node_args,
1213
is_quant_arg,
1314
q_op,
@@ -166,7 +167,7 @@ def process_placeholder(
166167
tensor = ts.TosaSerializerTensor(
167168
inputs[0].name,
168169
input_shape,
169-
ts.DType.INT8 if is_quant_arg(node) else inputs[0].dtype,
170+
get_quant_arg_dtype(node) if is_quant_arg(node) else inputs[0].dtype,
170171
data=None,
171172
placeholderFilename=inputs[0].name + ".npy",
172173
)

backends/arm/tosa_quant_utils.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
import serializer.tosa_serializer as ts
1212
import torch.fx
13-
from executorch.backends.arm.tosa_mapping import TosaArg
13+
from executorch.backends.arm.tosa_mapping import map_dtype, TosaArg
1414
from executorch.exir.dialects._ops import ops as exir_ops
1515
from serializer.tosa_serializer import TosaOp, TosaSerializerTensor
1616

@@ -45,11 +45,41 @@ def is_quant_node(node: torch.fx.Node):
4545
)
4646

4747

48+
def get_quant_node_dtype(node: torch.fx.Node):
49+
if "tosa" in node.target.__name__:
50+
return node.meta["val"].dtype
51+
52+
if node.target in dq_q_ops:
53+
return node.args[5]
54+
55+
# if not a tosa node, nor a q/dq op, walk the graph until we find a q op
56+
consumer_node = list(node.users)[0]
57+
while True:
58+
if consumer_node.target in dq_q_ops:
59+
return consumer_node.args[5]
60+
61+
# Try to move on to the next node
62+
if len(consumer_node.users) == 0:
63+
raise RuntimeError("No quantized node found in graph")
64+
consumer_node = list(consumer_node.users)[0]
65+
66+
4867
def is_quant_arg(arg):
4968
consumer_node = list(arg.users)[0]
5069
return consumer_node.target == q_op
5170

5271

72+
def get_quant_arg_dtype(node: torch.fx.Node):
73+
consumer_node = list(node.users)[0]
74+
75+
# Get type of quant node, args differ from per_tensor and per_channel.
76+
if consumer_node.target == q_op:
77+
if is_quant_arg(node):
78+
return map_dtype(consumer_node.args[5])
79+
else:
80+
raise RuntimeError("Quantization argument not found")
81+
82+
5383
def get_quant_node_args(node: torch.fx.Node):
5484
"""
5585
Get the quantization parameters from a quant node.

0 commit comments

Comments
 (0)