Skip to content

Commit e29a4b5

Browse files
Arm backend: Convert assert to raise TypeError in op_sub (#9958)
Asserts are converted to proper raises to ensure graph integrity. Signed-off-by: Sebastian Larsson <[email protected]>
1 parent 8e08b7c commit e29a4b5

File tree

1 file changed

+27
-5
lines changed

1 file changed

+27
-5
lines changed

backends/arm/operators/op_sub.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,19 @@ def define_node(
4040
) -> None:
4141
# Specification (0.80) states that input and output types
4242
# should all be the same
43-
assert inputs[0].dtype == inputs[1].dtype == output.dtype
43+
if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype:
44+
raise TypeError(
45+
f"All IO needs to have the same data type, got input 1: "
46+
f"{inputs[0].dtype}, input 2: {inputs[1].dtype} and output: "
47+
f"{output.dtype}"
48+
)
49+
4450
# Handle int8 (quantized) and int32
45-
assert inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]
51+
supported_dtypes = [ts.DType.INT8, ts.DType.INT32]
52+
if inputs[0].dtype not in supported_dtypes:
53+
raise TypeError(
54+
f'IO data type needs to be {supported_dtypes}, got "{inputs[0].dtype}"'
55+
)
4656

4757
if inputs[0].dtype == ts.DType.INT8:
4858
rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
@@ -97,15 +107,27 @@ def define_node(
97107
) -> None:
98108
# Specification (0.80) states that input and output types
99109
# should all be the same
100-
assert inputs[0].dtype == inputs[1].dtype == output.dtype
110+
if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype:
111+
raise TypeError(
112+
f"All IO needs to have the same data type, got input 1: "
113+
f"{inputs[0].dtype}, input 2: {inputs[1].dtype} and output: "
114+
f"{output.dtype}"
115+
)
101116

102117
if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]:
103118
# Call the inherited define_node for handling integers
104119
super().define_node(node, tosa_graph, inputs, output)
105120
else:
106121
# FP32 Sub lowering
107-
assert inputs[0].dtype == ts.DType.FP32
108-
assert output.dtype == ts.DType.FP32
122+
if (
123+
inputs[0].dtype != ts.DType.FP32
124+
or inputs[1].dtype != ts.DType.FP32
125+
or output.dtype != ts.DType.FP32
126+
):
127+
raise TypeError(
128+
f"All IO needs to have data type fp32. Got: {inputs[0].dtype}, "
129+
f"input 2: {inputs[1].dtype} and output: {output.dtype}"
130+
)
109131

110132
# MI lowering
111133
tosa_graph.addOperator(

0 commit comments

Comments
 (0)