11
11
import executorch .backends .arm .tosa_utils as tutils
12
12
13
13
import serializer .tosa_serializer as ts
14
+ import torch
14
15
from executorch .backends .arm .operators .node_visitor import (
15
16
NodeVisitor ,
16
17
register_node_visitor ,
@@ -40,27 +41,33 @@ def define_node(
40
41
output : TosaArg ,
41
42
is_quant_node : bool ,
42
43
) -> None :
43
- # Specification (0.80) states that input and output types
44
- # should all be the same
45
- assert inputs [0 ].dtype == inputs [1 ].dtype == output .dtype
46
- # Handle int8 (quantized) and int32
47
- assert inputs [0 ].dtype in [ts .DType .INT8 , ts .DType .INT32 ]
48
-
49
- if inputs [0 ].dtype == ts .DType .INT8 :
50
- rescaled_inputs , scale_back = tqutils .insert_rescale_ops_to_int32 (
51
- tosa_graph , inputs , node
44
+ input_nodes = tutils .get_two_inputs (node )
45
+
46
+ if not is_quant_node and not all (
47
+ tensor .meta ["val" ].dtype in (torch .int8 , torch .int32 )
48
+ for tensor in input_nodes
49
+ ):
50
+ raise RuntimeError (
51
+ f"Unexpected non quantized { AddVisitor_080_BI .target } node."
52
+ )
53
+
54
+ needs_rescale = not (
55
+ all (tensor .meta ["val" ].dtype == torch .int32 for tensor in input_nodes )
56
+ and node .meta ["val" ].dtype == torch .int32
57
+ )
58
+
59
+ if needs_rescale :
60
+ # Rescale inputs to 32 bit
61
+ rescaled_inputs , scale = tqutils .rescale_nodes_to_int32 (
62
+ input_nodes , tosa_graph
52
63
)
53
- else :
54
- # input[0].dtype == ts.DType.INT32
55
- # Non quantized input, natively support by TOSA.ADD
56
- rescaled_inputs = inputs
57
64
58
- if output . dtype == ts . DType . INT8 :
65
+ # Prepare add output tensor
59
66
broadcasted_shape = tutils .tosa_shape (output .shape , output .dim_order )
60
67
add_output = tosa_graph .addIntermediate (broadcasted_shape , ts .DType .INT32 )
61
68
else :
62
- # output.dtype == ts.DType.INT32
63
69
add_output = output
70
+ rescaled_inputs = inputs
64
71
65
72
# Do the INT32 Add
66
73
tosa_graph .addOperator (
@@ -73,12 +80,10 @@ def define_node(
73
80
None ,
74
81
)
75
82
76
- if output . dtype == ts . DType . INT8 :
83
+ if needs_rescale :
77
84
# Scale output back to 8 bit
78
85
# pyre-ignore
79
- tqutils .insert_rescale_node_back_to_int8 (
80
- tosa_graph , add_output , scale_back , node
81
- )
86
+ tqutils .rescale_node_back_to_int8 (node , add_output , scale , tosa_graph )
82
87
83
88
84
89
@register_node_visitor
@@ -100,19 +105,11 @@ def define_node(
100
105
output : TosaArg ,
101
106
is_quant_node : bool ,
102
107
) -> None :
103
- # Specification (0.80) states that input and output types
104
- # should all be the same
105
- assert inputs [0 ].dtype == inputs [1 ].dtype == output .dtype
106
-
107
- if inputs [0 ].dtype in [ts .DType .INT8 , ts .DType .INT32 ]:
108
+ if is_quant_node :
108
109
# Call the inherited define_node for handling integers
109
110
super ().define_node (node , tosa_graph , inputs , output , is_quant_node )
110
111
else :
111
112
# FP32 Add lowering
112
- assert inputs [0 ].dtype == ts .DType .FP32
113
- assert output .dtype == ts .DType .FP32
114
-
115
- # MI lowering
116
113
tosa_graph .addOperator (
117
114
TosaOp .Op ().ADD ,
118
115
[inputs [0 ].name , inputs [1 ].name ],
0 commit comments