Skip to content

Commit fd87e98

Browse files
Arm backend: Convert remaining asserts in operators to raise errors (#10945)
Asserts are converted to proper raises to ensure graph integrity, and the error messages are improved. Signed-off-by: Sebastian Larsson <[email protected]>
1 parent 5c6d4e5 commit fd87e98

File tree

3 files changed

+28
-12
lines changed

3 files changed

+28
-12
lines changed

backends/arm/operators/op_sub.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,11 @@ def define_node(
163163
validate_same_dtype(self.target, [*inputs, output])
164164

165165
# Handle int8 (quantized) and int32
166-
assert inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]
166+
supported_dtypes = [ts.DType.INT8, ts.DType.INT32]
167+
if inputs[0].dtype not in supported_dtypes:
168+
raise TypeError(
169+
f'IO data type needs to be {supported_dtypes}, got "{inputs[0].dtype}"'
170+
)
167171

168172
scale_back = 1.0
169173
if inputs[0].dtype == ts.DType.INT8:
@@ -228,8 +232,15 @@ def define_node(
228232
super().define_node(node, tosa_graph, inputs, output)
229233
else:
230234
# FP32 Sub lowering
231-
assert inputs[0].dtype == ts.DType.FP32
232-
assert output.dtype == ts.DType.FP32
235+
if (
236+
inputs[0].dtype != ts.DType.FP32
237+
or inputs[1].dtype != ts.DType.FP32
238+
or output.dtype != ts.DType.FP32
239+
):
240+
raise TypeError(
241+
f"All IO needs to have data type fp32. Got: {inputs[0].dtype}, "
242+
f"input 2: {inputs[1].dtype} and output: {output.dtype}"
243+
)
233244

234245
# MI lowering
235246
tosa_graph.addOperator(

backends/arm/operators/op_upsample_bilinear2d.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -153,9 +153,12 @@ def define_node(
153153
def in_int16_range(x):
154154
return torch.all(x >= -(2**15)) and torch.all(x <= 2**15 - 1)
155155

156-
assert in_int16_range(scale_n_yx)
157-
assert in_int16_range(scale_d_yx)
158-
assert in_int16_range(border_yx)
156+
if not in_int16_range(scale_n_yx):
157+
raise ValueError("scale_n_yx is out of the int16 range")
158+
if not in_int16_range(scale_d_yx):
159+
raise ValueError("scale_d_yx is out of the int16 range")
160+
if not in_int16_range(border_yx):
161+
raise ValueError("border_yx is out of the int16 range")
159162

160163
scales = [scale_n_yx[0], scale_d_yx[0], scale_n_yx[1], scale_d_yx[1]]
161164

backends/arm/operators/op_upsample_nearest2d.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -102,9 +102,8 @@ def define_node(
102102
validate_num_inputs(self.target, inputs, 3)
103103
validate_same_dtype(self.target, [inputs[0], output])
104104

105-
assert (
106-
inputs[0].shape is not None and output.shape is not None
107-
), "Only static shapes are supported"
105+
if inputs[0].shape is None or output.shape is None:
106+
raise ValueError("Only static shapes are supported")
108107

109108
# tosa_shape output is NHWC, take HW
110109
input_size_yx = torch.tensor(
@@ -121,9 +120,12 @@ def define_node(
121120
def in_int16_range(x):
122121
return torch.all(x >= -(2**15)) and torch.all(x <= 2**15 - 1)
123122

124-
assert in_int16_range(scale_n_yx)
125-
assert in_int16_range(scale_d_yx)
126-
assert in_int16_range(border_yx)
123+
if not in_int16_range(scale_n_yx):
124+
raise ValueError("scale_n_yx is out of the int16 range")
125+
if not in_int16_range(scale_d_yx):
126+
raise ValueError("scale_d_yx is out of the int16 range")
127+
if not in_int16_range(border_yx):
128+
raise ValueError("border_yx is out of the int16 range")
127129

128130
scales = [scale_n_yx[0], scale_d_yx[0], scale_n_yx[1], scale_d_yx[1]]
129131
scales_tensor = tosa_graph.addConst(

0 commit comments

Comments (0)