Skip to content

Commit 02d315d

Browse files
Arm backend: Check in tosa.fbs for TOSA 0.80 and 1.0 (#10870)
### Summary Add schema files for 0.80 and 1.0. This will enable dump_artifact and dump_operator_distribution for both 0.80 and 1.0. Also bumps reference model. ### Test plan Tested on internal and external GitHub CI. Signed-off-by: Oscar Andersson <[email protected]>
1 parent b09e793 commit 02d315d

File tree

5 files changed

+871
-2
lines changed

5 files changed

+871
-2
lines changed

backends/arm/scripts/install_reference_model.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ tosa_reference_model_url="https://git.gitlab.arm.com/tosa/tosa-reference-model.g
1313
tosa_reference_model_0_80_branch="v0.80"
1414
tosa_reference_model_0_80_rev="70ed0b40fa831387e36abdb4f7fb9670a3464f5a"
1515
tosa_serialization_lib_0_80_rev="v0.80.1"
16-
tosa_reference_model_1_0_rev="f9b4ceb850964be03a39e965ad7a0546dc6c57ae"
16+
tosa_reference_model_1_0_rev="4d17b5b960cd986d8cb8052188fbe3ae494789e8"
1717

1818
script_dir=$(cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd)
1919

backends/arm/test/runner_utils.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from torch.fx.node import Node
3232

3333
from torch.overrides import TorchFunctionMode
34+
from tosa.TosaGraph import TosaGraph
3435

3536
logger = logging.getLogger(__name__)
3637

@@ -461,10 +462,19 @@ def dbg_tosa_fb_to_json(tosa_fb: bytes) -> Dict:
461462
tosa_input_file = os.path.join(tmp, "output.tosa")
462463
with open(tosa_input_file, "wb") as f:
463464
f.write(tosa_fb)
465+
tosa_graph = TosaGraph.GetRootAsTosaGraph(tosa_fb)
466+
version = tosa_graph.Version()
467+
major = version._Major()
468+
minor = version._Minor()
469+
patch = version._Patch()
470+
if not ((major == 1 and minor == 0) or (major == 0 and minor == 80)):
471+
raise RuntimeError(
472+
f"Unsupported version in TOSA flatbuffer: version={major}.{minor}.{patch}"
473+
)
464474

465475
arm_backend_path = os.path.realpath(os.path.dirname(__file__) + "/..")
466476
tosa_schema_file = os.path.join(
467-
arm_backend_path, "third-party/serialization_lib/schema/tosa.fbs"
477+
arm_backend_path, f"tosa/schemas/tosa_{major}.{minor}.fbs"
468478
)
469479
assert os.path.exists(
470480
tosa_schema_file

backends/arm/tosa/schemas/README.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# License
2+
3+
The FlatBuffer schema (fbs) files originates from
4+
https://git.mlplatform.org/tosa/reference_model.git/ and are relicensed under the BSD-style license
5+
file found in the [LICENSE](../../../../LICENSE) file in the root directory of this source tree.
Lines changed: 314 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,314 @@
1+
// Copyright 2025 Arm Limited and/or its affiliates.
2+
//
3+
// This source code is licensed under the BSD-style license found in the
4+
// LICENSE file in the root directory of this source tree.
5+
6+
namespace tosa;
7+
8+
// This corresponds to the version.
9+
file_identifier "TOSA";
10+
// File extension of any written files.
11+
file_extension "tosa";
12+
13+
// NOTE: New values added to the schema should be placed
14+
// at the end of the list in order to keep schema stable.
15+
16+
enum DType:uint32 {
17+
UNKNOWN = 0,
18+
BOOL,
19+
UINT8,
20+
INT4,
21+
INT8,
22+
INT16,
23+
INT32,
24+
INT48,
25+
FP32,
26+
UINT16,
27+
FP16,
28+
BF16,
29+
SHAPE,
30+
}
31+
32+
enum ResizeMode:uint32 {
33+
UNKNOWN = 0,
34+
NEAREST,
35+
BILINEAR,
36+
}
37+
38+
enum Op:uint32 {
39+
UNKNOWN = 0,
40+
ARGMAX,
41+
AVG_POOL2D,
42+
CONV2D,
43+
CONV3D,
44+
DEPTHWISE_CONV2D,
45+
FULLY_CONNECTED,
46+
MATMUL,
47+
MAX_POOL2D,
48+
TRANSPOSE_CONV2D,
49+
CLAMP,
50+
RESERVED,
51+
SIGMOID,
52+
TANH,
53+
ADD,
54+
ARITHMETIC_RIGHT_SHIFT,
55+
BITWISE_AND,
56+
BITWISE_OR,
57+
BITWISE_XOR,
58+
INTDIV,
59+
LOGICAL_AND,
60+
LOGICAL_LEFT_SHIFT,
61+
LOGICAL_RIGHT_SHIFT,
62+
LOGICAL_OR,
63+
LOGICAL_XOR,
64+
MAXIMUM,
65+
MINIMUM,
66+
MUL,
67+
POW,
68+
SUB,
69+
TABLE,
70+
ABS,
71+
BITWISE_NOT,
72+
CEIL,
73+
CLZ,
74+
EXP,
75+
FLOOR,
76+
LOG,
77+
LOGICAL_NOT,
78+
NEGATE,
79+
RECIPROCAL,
80+
RSQRT,
81+
SELECT,
82+
EQUAL,
83+
GREATER,
84+
GREATER_EQUAL,
85+
REDUCE_ANY,
86+
REDUCE_ALL,
87+
REDUCE_MAX,
88+
REDUCE_MIN,
89+
REDUCE_PRODUCT,
90+
REDUCE_SUM,
91+
CONCAT,
92+
PAD,
93+
RESHAPE,
94+
REVERSE,
95+
SLICE,
96+
TILE,
97+
TRANSPOSE,
98+
GATHER,
99+
SCATTER,
100+
RESIZE,
101+
CAST,
102+
RESCALE,
103+
CONST,
104+
IDENTITY,
105+
CUSTOM,
106+
COND_IF,
107+
WHILE_LOOP,
108+
FFT2D,
109+
RFFT2D,
110+
ERF,
111+
DIM,
112+
}
113+
114+
union Attribute {
115+
PoolAttribute,
116+
ConvAttribute,
117+
TransposeConvAttribute,
118+
PadAttribute,
119+
AxisAttribute,
120+
ReshapeAttribute,
121+
SliceAttribute,
122+
TileAttribute,
123+
ResizeAttribute,
124+
ClampAttribute,
125+
RescaleAttribute,
126+
MulAttribute,
127+
ArithmeticRightShiftAttribute,
128+
CondIfAttribute,
129+
WhileLoopAttribute,
130+
TransposeAttribute,
131+
TableAttribute,
132+
MatMulAttribute,
133+
FullyConnectedAttribute,
134+
NegateAttribute,
135+
CustomAttribute,
136+
FFTAttribute,
137+
RFFTAttribute,
138+
}
139+
140+
table PoolAttribute {
141+
pad: [int32];
142+
kernel: [int32];
143+
stride: [int32];
144+
input_zp: int32;
145+
output_zp: int32;
146+
accum_dtype: DType;
147+
}
148+
149+
table ConvAttribute {
150+
pad: [int32];
151+
stride: [int32];
152+
dilation: [int32];
153+
input_zp: int32;
154+
weight_zp: int32;
155+
local_bound: bool;
156+
}
157+
158+
table TransposeConvAttribute {
159+
out_pad: [int32];
160+
stride: [int32];
161+
output_shape: [int32];
162+
input_zp: int32;
163+
weight_zp: int32;
164+
local_bound: bool;
165+
}
166+
167+
table PadAttribute {
168+
padding: [int32];
169+
pad_const_int: int32;
170+
pad_const_fp: [ubyte] (force_align: 8);
171+
}
172+
173+
table AxisAttribute {
174+
axis: int32;
175+
}
176+
177+
table ReshapeAttribute {
178+
new_shape: [int32];
179+
}
180+
181+
table SliceAttribute {
182+
start: [int32];
183+
size: [int32];
184+
}
185+
186+
table TileAttribute {
187+
multiples: [int32];
188+
}
189+
190+
table ResizeAttribute {
191+
scale: [int16];
192+
offset: [int16];
193+
border: [int16];
194+
mode: ResizeMode;
195+
}
196+
197+
table ClampAttribute {
198+
min_int: int32;
199+
max_int: int32;
200+
min_fp: [ubyte] (force_align: 8);
201+
max_fp: [ubyte] (force_align: 8);
202+
}
203+
204+
table RescaleAttribute {
205+
input_zp: int32;
206+
output_zp: int32;
207+
multiplier: [int32];
208+
shift: [int32];
209+
scale32: bool;
210+
double_round: bool;
211+
per_channel: bool;
212+
input_unsigned: bool;
213+
output_unsigned: bool;
214+
}
215+
216+
table MulAttribute {
217+
shift: int32;
218+
}
219+
220+
table ArithmeticRightShiftAttribute {
221+
round: bool;
222+
}
223+
224+
table CondIfAttribute {
225+
then_branch: string;
226+
else_branch: string;
227+
}
228+
229+
table WhileLoopAttribute {
230+
cond_branch: string;
231+
body_branch: string;
232+
}
233+
234+
table TransposeAttribute {
235+
perms: [int32];
236+
}
237+
238+
table TableAttribute {
239+
table: [int16];
240+
}
241+
242+
table MatMulAttribute {
243+
a_zp: int32;
244+
b_zp: int32;
245+
}
246+
247+
table FullyConnectedAttribute {
248+
input_zp: int32;
249+
weight_zp: int32;
250+
}
251+
252+
table NegateAttribute {
253+
input1_zp: int32;
254+
output_zp: int32;
255+
}
256+
257+
table CustomAttribute {
258+
operator_name:string;
259+
domain_name:string;
260+
implementation_attrs:[ubyte];
261+
}
262+
263+
table FFTAttribute {
264+
inverse: bool;
265+
local_bound: bool;
266+
}
267+
268+
table RFFTAttribute {
269+
local_bound: bool;
270+
}
271+
272+
table Version {
273+
_major: int32 = -1;
274+
_minor: int32 = -1;
275+
_patch: int32 = -1;
276+
_draft: bool = true;
277+
}
278+
279+
table TosaTensor {
280+
name:string; // name of the tensor, used for solving dependency
281+
shape:[int32]; // shape of the tensor
282+
type:DType; // data type of the tensor
283+
data: [ubyte] (force_align: 8); // raw data array if it's a constant tensor.
284+
variable: bool; // is this a variable tensor
285+
is_unranked: bool; // whether this is an unranked tensor
286+
variable_name:string; // name for variable attribute
287+
}
288+
289+
table TosaOperator {
290+
op:Op; // operator enum
291+
attribute:Attribute; // union structure. operator attribute
292+
inputs:[string]; // list of input tensor names
293+
outputs:[string]; // list of output tensor names
294+
}
295+
296+
table TosaBasicBlock {
297+
name:string; // basic block name
298+
operators:[TosaOperator]; // operators array
299+
tensors:[TosaTensor]; // tensors array
300+
inputs:[string]; // name of graph inputs
301+
outputs:[string]; // name of graph outputs
302+
}
303+
304+
table TosaRegion {
305+
name:string; // name of region
306+
blocks:[TosaBasicBlock]; // basic blocks array
307+
}
308+
309+
table TosaGraph {
310+
version:Version (required);
311+
regions:[TosaRegion]; // regions array
312+
}
313+
314+
root_type TosaGraph;

0 commit comments

Comments
 (0)