@@ -40,9 +40,19 @@ def define_node(
40
40
) -> None :
41
41
# Specification (0.80) states that input and output types
42
42
# should all be the same
43
- assert inputs [0 ].dtype == inputs [1 ].dtype == output .dtype
43
+ if inputs [0 ].dtype != inputs [1 ].dtype or inputs [0 ].dtype != output .dtype :
44
+ raise TypeError (
45
+ f"All IO needs to have the same data type, got input 1: "
46
+ f"{ inputs [0 ].dtype } , input 2: { inputs [1 ].dtype } and output: "
47
+ f"{ output .dtype } "
48
+ )
49
+
44
50
# Handle int8 (quantized) and int32
45
- assert inputs [0 ].dtype in [ts .DType .INT8 , ts .DType .INT32 ]
51
+ supported_dtypes = [ts .DType .INT8 , ts .DType .INT32 ]
52
+ if inputs [0 ].dtype not in supported_dtypes :
53
+ raise TypeError (
54
+ f'IO data type needs to be { supported_dtypes } , got "{ inputs [0 ].dtype } "'
55
+ )
46
56
47
57
if inputs [0 ].dtype == ts .DType .INT8 :
48
58
rescaled_inputs , scale_back = tqutils .insert_rescale_ops_to_int32 (
@@ -97,15 +107,27 @@ def define_node(
97
107
) -> None :
98
108
# Specification (0.80) states that input and output types
99
109
# should all be the same
100
- assert inputs [0 ].dtype == inputs [1 ].dtype == output .dtype
110
+ if inputs [0 ].dtype != inputs [1 ].dtype or inputs [0 ].dtype != output .dtype :
111
+ raise TypeError (
112
+ f"All IO needs to have the same data type, got input 1: "
113
+ f"{ inputs [0 ].dtype } , input 2: { inputs [1 ].dtype } and output: "
114
+ f"{ output .dtype } "
115
+ )
101
116
102
117
if inputs [0 ].dtype in [ts .DType .INT8 , ts .DType .INT32 ]:
103
118
# Call the inherited define_node for handling integers
104
119
super ().define_node (node , tosa_graph , inputs , output )
105
120
else :
106
121
# FP32 Sub lowering
107
- assert inputs [0 ].dtype == ts .DType .FP32
108
- assert output .dtype == ts .DType .FP32
122
+ if (
123
+ inputs [0 ].dtype != ts .DType .FP32
124
+ or inputs [1 ].dtype != ts .DType .FP32
125
+ or output .dtype != ts .DType .FP32
126
+ ):
127
+ raise TypeError (
128
+ f"All IO needs to have data type fp32. Got: { inputs [0 ].dtype } , "
129
+ f"input 2: { inputs [1 ].dtype } and output: { output .dtype } "
130
+ )
109
131
110
132
# MI lowering
111
133
tosa_graph .addOperator (
0 commit comments