
Commit 37c365d

Arm backend: Refactor TosaArg to use tosa_spec
Refactor TosaArg to take TosaSpecification as an optional input. The spec is used for mapping torch.dtypes to TOSA dtypes.

Signed-off-by: Oscar Andersson <[email protected]>
Change-Id: I2d52da4eededc36f8daefb3cb46214e9d374d306
1 parent db2fd03 commit 37c365d
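For orientation, a minimal usage sketch of the refactored interface (not part of the commit): the spec string below follows the profile naming the backend already uses, and graph_node stands in for any torch.fx.Node produced by the export pipeline.

import torch.fx

from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_specification import TosaSpecification

# Resolve the active TOSA profile once (assumed spec string; pick the one you target).
tosa_spec = TosaSpecification.create_from_string("TOSA-0.80+BI")


def to_tosa_arg(graph_node: torch.fx.Node) -> TosaArg:
    # TosaArg now takes the spec so it can map torch dtypes to the matching
    # serializer's TOSA enums; constructing it without a spec raises a ValueError.
    return TosaArg(graph_node, tosa_spec)

Most of the hunks below simply thread this spec from the partitioner/compile spec down through process_node and into TosaArg.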

6 files changed: +91 -52 lines changed


backends/arm/_passes/convert_split_to_slice.py

Lines changed: 6 additions & 5 deletions
@@ -1,14 +1,15 @@
-# Copyright 2024 Arm Limited and/or its affiliates.
-# All rights reserved.
+# Copyright 2024-2025 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
 # pyre-unsafe
 
 import torch.fx
-from executorch.backends.arm._passes.arm_pass_utils import create_node
-from executorch.backends.arm.tosa_mapping import extract_tensor_meta
+from executorch.backends.arm._passes.arm_pass_utils import (
+    create_node,
+    get_first_fake_tensor,
+)
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass, PassResult
 
@@ -34,7 +35,7 @@ def call(self, graph_module: torch.fx.GraphModule):
             split_node = node
             input_node = split_node.all_input_nodes[0]
             output_nodes = split_node.users.copy()
-            _, shape, _ = extract_tensor_meta(input_node.meta)
+            shape = get_first_fake_tensor(input_node).shape
             rank = len(shape)
             split_lengths = split_node.args[1]
             dim = split_node.args[2] if len(split_node.args) > 2 else 0
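Because extract_tensor_meta now requires a TosaSpecification, passes that only care about shapes read them straight off the node's FakeTensor instead. A minimal sketch of that pattern; the wrapper function name is illustrative, while get_first_fake_tensor is the real utility used in the diff.

import torch.fx

from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor


def shape_and_rank(node: torch.fx.Node) -> tuple[torch.Size, int]:
    # Shape-only queries no longer need a TosaSpecification.
    shape = get_first_fake_tensor(node).shape
    return shape, len(shape)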

backends/arm/operators/op_rescale.py

Lines changed: 7 additions & 7 deletions
@@ -13,7 +13,7 @@
     NodeVisitor,
     register_node_visitor,
 )
-from executorch.backends.arm.tosa_mapping import map_dtype, TosaArg
+from executorch.backends.arm.tosa_mapping import TosaArg
 from executorch.backends.arm.tosa_quant_utils import create_const_ops_for_rescale
 
 from executorch.backends.arm.tosa_specification import TosaSpecification
@@ -35,15 +35,15 @@ def define_node(
     ) -> None:
         import tosa_tools.v0_80.serializer.tosa_serializer as ts  # type: ignore
 
-        input_dtype = inputs[0].dtype
+        input_dtype = node.all_input_nodes[0].meta["val"].dtype
         output_dtype = cast(torch.dtype, node.args[1])
         scale = cast(float, node.args[2])
         input_zp = cast(int, node.args[3])
         output_zp = cast(int, node.args[4])
 
-        if input_dtype != map_dtype(torch.int8) and input_zp != 0:
+        if input_dtype != torch.int8 and input_zp != 0:
             raise ValueError(
-                f"If input dtype is not int8, input_zp must be 0. Got input_dtype{ts.DTypeNames[input_dtype]}, {input_zp=}"
+                f"If input dtype is not int8, input_zp must be 0. Got input_dtype{input_dtype=}, {input_zp=}"
             )
         if output_dtype != torch.int8 and output_zp != 0:
             raise ValueError(
@@ -91,15 +91,15 @@ def define_node(
         import serializer.tosa_serializer as ts  # type: ignore
         from tosa.RoundingMode import RoundingMode  # type: ignore
 
-        input_dtype = inputs[0].dtype
+        input_dtype = node.all_input_nodes[0].meta["val"].dtype
         output_dtype = cast(torch.dtype, node.args[1])
         scale = cast(float, node.args[2])
         input_zp = cast(int, node.args[3])
         output_zp = cast(int, node.args[4])
 
-        if input_dtype != map_dtype(torch.int8) and input_zp != 0:
+        if input_dtype != torch.int8 and input_zp != 0:
             raise ValueError(
-                f"If input dtype is not int8, input_zp must be 0. Got input_dtype{ts.DTypeNames[input_dtype]}, {input_zp=}"
+                f"If input dtype is not int8, input_zp must be 0. Got input_dtype{input_dtype=}, {input_zp=}"
             )
         if output_dtype != torch.int8 and output_zp != 0:
             raise ValueError(
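With the mapping now spec-dependent, the rescale visitor compares plain torch dtypes instead of pre-mapped TOSA enums. A sketch of the zero-point check under the same assumptions as the diff (the wrapper function name is hypothetical; the node's first input carries a FakeTensor in meta["val"]):

import torch
import torch.fx


def check_rescale_zero_points(node: torch.fx.Node, input_zp: int, output_zp: int) -> None:
    # Read the torch dtype directly from the FX graph; no TOSA enum mapping needed here.
    input_dtype = node.all_input_nodes[0].meta["val"].dtype
    output_dtype = node.args[1]
    if input_dtype != torch.int8 and input_zp != 0:
        raise ValueError(f"If input dtype is not int8, input_zp must be 0. Got {input_dtype=}, {input_zp=}")
    if output_dtype != torch.int8 and output_zp != 0:
        raise ValueError(f"If output dtype is not int8, output_zp must be 0. Got {output_dtype=}, {output_zp=}")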

backends/arm/process_node.py

Lines changed: 12 additions & 8 deletions
@@ -36,11 +36,11 @@ def process_call_function(
     tosa_spec: TosaSpecification,
 ):
     # Unpack arguments and convert
-    inputs = getNodeArgs(node)
+    inputs = getNodeArgs(node, tosa_spec)
 
     # Convert output (this node itself)
     try:
-        output = TosaArg(node)
+        output = TosaArg(node, tosa_spec)
     except ValueError as e:
         raise ValueError(
             f"Failed processing call_function: {node.name}. "
@@ -78,7 +78,7 @@ def process_inputs(
             f"Expected dim_order: {tuple(range(meta.dim()))}, but got: {meta.dim_order()} for node {node.name}"
         )
     try:
-        tosa_arg = TosaArg(node)
+        tosa_arg = TosaArg(node, tosa_spec)
     except ValueError as e:
         raise ValueError(
             f"Failed processing input placeholder: {node.name}. "
@@ -112,7 +112,7 @@ def process_inputs_to_parameters(
 ):
     """Serialize bias and non-quantized weights"""
     try:
-        tosa_arg = TosaArg(node)
+        tosa_arg = TosaArg(node, tosa_spec)
     except ValueError as e:
         raise ValueError(
             f"Failed processing parameter placeholder: {node.name}. "
@@ -137,10 +137,11 @@ def process_inputs_to_buffers(
     node: torch.fx.Node,
     tosa_graph: Any,
     edge_program: ExportedProgram,
+    tosa_spec: TosaSpecification,
 ):
     """Serialize quantized weights"""
     try:
-        tosa_arg = TosaArg(node)
+        tosa_arg = TosaArg(node, tosa_spec)
     except ValueError as e:
         raise ValueError(
             f"Failed processing buffer placeholder: {node.name}. "
@@ -165,9 +166,10 @@ def process_inputs_to_lifted_tensor_constants(
     node: torch.fx.Node,
     tosa_graph: Any,
     edge_program: ExportedProgram,
+    tosa_spec: TosaSpecification,
 ):
     try:
-        tosa_arg = TosaArg(node)
+        tosa_arg = TosaArg(node, tosa_spec)
     except ValueError as e:
         raise ValueError(
             f"Failed processing lifted tensor constant placeholder: {node.name}. "
@@ -196,9 +198,11 @@ def process_placeholder(
     elif is_param(edge_program, node):
         process_inputs_to_parameters(node, tosa_graph, edge_program, tosa_spec)
     elif is_buffer(edge_program, node):
-        process_inputs_to_buffers(node, tosa_graph, edge_program)
+        process_inputs_to_buffers(node, tosa_graph, edge_program, tosa_spec)
     elif is_lifted_tensor_constant(edge_program, node):
-        process_inputs_to_lifted_tensor_constants(node, tosa_graph, edge_program)
+        process_inputs_to_lifted_tensor_constants(
+            node, tosa_graph, edge_program, tosa_spec
+        )
     elif node.name in edge_program.graph_signature.inputs_to_lifted_custom_objs:
         raise NotImplementedError(
             "Placeholder is of type 'lifted custom object' which is not supported."

backends/arm/test/tester/arm_tester.py

Lines changed: 9 additions & 3 deletions
@@ -48,6 +48,7 @@
 )
 from executorch.backends.arm.tosa_mapping import extract_tensor_meta
 from executorch.backends.arm.tosa_partitioner import TOSAPartitioner
+from executorch.backends.arm.tosa_specification import TosaSpecification
 
 from executorch.backends.xnnpack.test.tester import Tester
 from executorch.devtools.backend_debug import get_delegation_info
@@ -564,7 +565,10 @@ def dump_dtype_distribution(
         )
 
         graph = self.get_graph(self.cur)
-        dtype_dist_placeholders, dtype_dirst_tensors = _get_dtype_distribution(graph)
+        tosa_spec = get_tosa_spec(self.compile_spec)
+        dtype_dist_placeholders, dtype_dirst_tensors = _get_dtype_distribution(
+            graph, tosa_spec
+        )
         all_dtypes = set(dtype_dist_placeholders.keys()) | set(
             dtype_dirst_tensors.keys()
         )
@@ -659,7 +663,9 @@ def _compare_outputs(
        raise e
 
 
-def _get_dtype_distribution(graph: Graph) -> tuple[dict, dict]:
+def _get_dtype_distribution(
+    graph: Graph, tosa_spec: TosaSpecification
+) -> tuple[dict, dict]:
     """Counts the occurences of placeholder and call_function dtypes in a graph.
     The result is a tuple of Counters (placeholder_distribution, call_function_distribution)
     """
@@ -670,7 +676,7 @@ def _get_dtype_distribution(graph: Graph) -> tuple[dict, dict]:
             placeholder_dtypes.append(str(node.meta["val"].dtype))
         if node.op == "call_function":
             if "val" in node.meta:
-                dtype, _, _ = extract_tensor_meta(node.meta)
+                dtype, _, _ = extract_tensor_meta(node.meta, tosa_spec)
                 call_function_dtypes.append(ts.DTypeNames[dtype])
     return Counter(placeholder_dtypes), Counter(call_function_dtypes)
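In the tester, the spec is resolved from the compile spec and passed into the dtype census, since extract_tensor_meta now maps dtypes through it. A fragment of the new call shape, assuming tester is an ArmTester instance and get_tosa_spec / _get_dtype_distribution are the helpers referenced in this hunk:

tosa_spec = get_tosa_spec(tester.compile_spec)            # active TOSA profile
graph = tester.get_graph(tester.cur)                      # fx.Graph of the current stage
placeholders, call_functions = _get_dtype_distribution(graph, tosa_spec)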

backends/arm/tosa_mapping.py

Lines changed: 53 additions & 26 deletions
@@ -11,12 +11,14 @@
 # the standardised TOSA representation.
 #
 
-from typing import Any, Sequence
+from typing import Any, Optional, Sequence
 
 import torch
-
-import tosa_tools.v0_80.serializer.tosa_serializer as ts  # type: ignore
-
+from executorch.backends.arm.tosa_specification import (
+    Tosa_0_80,
+    Tosa_1_00,
+    TosaSpecification,
+)
 
 UNSUPPORTED_DTYPES = (
     torch.float64,
@@ -30,33 +32,39 @@
     torch.long,
 )
 
-DTYPE_MAP = {
-    torch.float32: ts.DType.FP32,
-    torch.float: ts.DType.FP32,
-    torch.float16: ts.DType.FP16,
-    torch.half: ts.DType.FP16,
-    torch.bfloat16: ts.DType.BF16,
-    torch.int8: ts.DType.INT8,
-    torch.int16: ts.DType.INT16,
-    torch.short: ts.DType.INT16,
-    torch.int32: ts.DType.INT32,
-    torch.int: ts.DType.INT32,
-    torch.bool: ts.DType.BOOL,
-}
-
-
-def map_dtype(data_type: torch.dtype) -> ts.DType:
+
+def map_dtype(data_type: torch.dtype, tosa_spec: TosaSpecification) -> Any:
     if data_type in UNSUPPORTED_DTYPES:
         raise ValueError(f"Unsupported type: {data_type}")
-    if data_type not in DTYPE_MAP:
+    if isinstance(tosa_spec, Tosa_0_80):
+        import tosa_tools.v0_80.serializer.tosa_serializer as ts  # type: ignore
+    elif isinstance(tosa_spec, Tosa_1_00):
+        import serializer.tosa_serializer as ts  # type: ignore
+    else:
+        raise RuntimeError(f"Unsupported tosa_spec: {tosa_spec}")
+
+    dtype_map = {
+        torch.float32: ts.DType.FP32,
+        torch.float: ts.DType.FP32,
+        torch.float16: ts.DType.FP16,
+        torch.half: ts.DType.FP16,
+        torch.bfloat16: ts.DType.BF16,
+        torch.int8: ts.DType.INT8,
+        torch.int16: ts.DType.INT16,
+        torch.short: ts.DType.INT16,
+        torch.int32: ts.DType.INT32,
+        torch.int: ts.DType.INT32,
+        torch.bool: ts.DType.BOOL,
+    }
+    if data_type not in dtype_map:
         raise ValueError(f"Unknown type: {data_type}")
-    return DTYPE_MAP[data_type]
+    return dtype_map[data_type]
 
 
 # Returns the shape and type of a node
 # TODO: other types, can be
 # SymInt, FakeTensor, a List[Union[FakeTensor, SymInt]], or None
-def extract_tensor_meta(meta):
+def extract_tensor_meta(meta, tosa_spec: TosaSpecification):
     assert meta.get("val") is not None
     val = meta["val"]
     if type(val) is tuple:
@@ -67,7 +75,7 @@ def extract_tensor_meta(meta):
         raise ValueError(
             f"Expected first value in node.meta['val'] to be FakeTensor, got {val.__class__}"
         )
-    dtype = map_dtype(val.dtype)
+    dtype = map_dtype(val.dtype, tosa_spec)
     shape = tuple(val.size())
 
     if meta.get("tosa_dim_order") is not None:
@@ -81,17 +89,28 @@ def extract_tensor_meta(meta):
 class TosaArg:
     def __process_node(self, argument: torch.fx.Node):
         self.name: str = argument.name
-        self.dtype, self.shape, self.dim_order = extract_tensor_meta(argument.meta)
+        self.dtype, self.shape, self.dim_order = extract_tensor_meta(
+            argument.meta, self.tosa_spec
+        )
 
     def __process_list(self, argument):
         self.special: list = list(argument)
 
     def __process_number(self, argument: float | int):
         self.number: float | int = argument
 
-    def __init__(self, argument: Any) -> None:
+    def __init__(
+        self, argument: Any, tosa_spec: Optional[TosaSpecification] = None
+    ) -> None:
         if argument is None:
             return
+        if tosa_spec is None:
+            raise ValueError("tosa_spec is None")
+        elif not isinstance(tosa_spec, TosaSpecification):
+            raise ValueError(
+                f"Expected tosa_spec to be a TosaSpecification, but got {tosa_spec}"
+            )
+        self.tosa_spec = tosa_spec
 
         if isinstance(argument, torch.fx.Node):
             self.__process_node(argument)
@@ -116,6 +135,12 @@ def __repr__(self):
         if self.name is not None:
             attrs.append(f"name={self.name!r}")
         if self.dtype is not None:
+            if isinstance(self.tosa_spec, Tosa_0_80):
+                import tosa_tools.v0_80.serializer.tosa_serializer as ts  # type: ignore
+            elif isinstance(self.tosa_spec, Tosa_1_00):
+                import serializer.tosa_serializer as ts  # type: ignore
+            else:
+                raise RuntimeError(f"Unsupported tosa_spec: {self.tosa_spec}")
             attrs.append(f"dtype={ts.DTypeNames[self.dtype]}")
         if self.shape is not None:
             attrs.append(f"shape={self.shape!r}")
@@ -125,4 +150,6 @@ def __repr__(self):
             attrs.append(f"special={self.special!r}")
         if hasattr(self, "number") and self.number is not None:
             attrs.append(f"number={self.number!r}")
+        if hasattr(self, "tosa_spec") and self.tosa_spec is not None:
+            attrs.append(f"tosa_spec={self.tosa_spec!r}")
         return f"{self.__class__.__name__}({', '.join(attrs)})"

backends/arm/tosa_utils.py

Lines changed: 4 additions & 3 deletions
@@ -10,10 +10,11 @@
 from typing import Any, Optional, Tuple
 
 import torch
-
 import tosa_tools.v0_80.serializer.tosa_serializer as ts  # type: ignore
 from executorch.backends.arm.tosa_mapping import TosaArg
 
+from executorch.backends.arm.tosa_specification import TosaSpecification
+
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.print_program import inspect_node
 from torch.fx import Node
@@ -93,9 +94,9 @@ def dbg_fail(
     dbg_node(node, graph_module)
 
 
-def getNodeArgs(node: Node) -> list[TosaArg]:
+def getNodeArgs(node: Node, tosa_spec: TosaSpecification) -> list[TosaArg]:
     try:
-        return [TosaArg(arg) for arg in node.args]
+        return [TosaArg(arg, tosa_spec) for arg in node.args]
     except ValueError as e:
        raise ValueError(f"Failed processing args to op:\n{node}") from e
101102
