11
11
import executorch .backends .arm .tosa_utils as tutils
12
12
13
13
import serializer .tosa_serializer as ts
14
- import torch
15
14
from executorch .backends .arm .operators .node_visitor import (
16
15
NodeVisitor ,
17
16
register_node_visitor ,
@@ -41,33 +40,27 @@ def define_node(
41
40
output : TosaArg ,
42
41
is_quant_node : bool ,
43
42
) -> None :
44
- input_nodes = tutils .get_two_inputs (node )
45
-
46
- if not is_quant_node and not all (
47
- tensor .meta ["val" ].dtype in (torch .int8 , torch .int32 )
48
- for tensor in input_nodes
49
- ):
50
- raise RuntimeError (
51
- f"Unexpected non quantized { AddVisitor_080_BI .target } node."
52
- )
53
-
54
- needs_rescale = not (
55
- all (tensor .meta ["val" ].dtype == torch .int32 for tensor in input_nodes )
56
- and node .meta ["val" ].dtype == torch .int32
57
- )
58
-
59
- if needs_rescale :
60
- # Rescale inputs to 32 bit
61
- rescaled_inputs , scale = tqutils .rescale_nodes_to_int32 (
62
- input_nodes , tosa_graph
43
+ # Specification (0.80.0) states that input and output types
44
+ # should all be the same
45
+ assert inputs [0 ].dtype == inputs [1 ].dtype == output .dtype
46
+ # Handle int8 (quantized) and int32
47
+ assert inputs [0 ].dtype in [ts .DType .INT8 , ts .DType .INT32 ]
48
+
49
+ if inputs [0 ].dtype == ts .DType .INT8 :
50
+ rescaled_inputs , scale_back = tqutils .insert_rescale_ops_to_int32 (
51
+ tosa_graph , inputs , node
63
52
)
53
+ else :
54
+ # input[0].dtype == ts.DType.INT32
55
+ # Non quantized input, natively support by TOSA.ADD
56
+ rescaled_inputs = inputs
64
57
65
- # Prepare add output tensor
58
+ if output . dtype == ts . DType . INT8 :
66
59
broadcasted_shape = tutils .tosa_shape (output .shape , output .dim_order )
67
60
add_output = tosa_graph .addIntermediate (broadcasted_shape , ts .DType .INT32 )
68
61
else :
62
+ # output.dtype == ts.DType.INT32
69
63
add_output = output
70
- rescaled_inputs = inputs
71
64
72
65
# Do the INT32 Add
73
66
tosa_graph .addOperator (
@@ -80,10 +73,12 @@ def define_node(
80
73
None ,
81
74
)
82
75
83
- if needs_rescale :
76
+ if output . dtype == ts . DType . INT8 :
84
77
# Scale output back to 8 bit
85
78
# pyre-ignore
86
- tqutils .rescale_node_back_to_int8 (node , add_output , scale , tosa_graph )
79
+ tqutils .insert_rescale_node_back_to_int8 (
80
+ tosa_graph , add_output , scale_back , node
81
+ )
87
82
88
83
89
84
@register_node_visitor
@@ -105,11 +100,19 @@ def define_node(
105
100
output : TosaArg ,
106
101
is_quant_node : bool ,
107
102
) -> None :
108
- if is_quant_node :
103
+ # Specification (0.80.0) states that input and output types
104
+ # should all be the same
105
+ assert inputs [0 ].dtype == inputs [1 ].dtype == output .dtype
106
+
107
+ if inputs [0 ].dtype in [ts .DType .INT8 , ts .DType .INT32 ]:
109
108
# Call the inherited define_node for handling integers
110
109
super ().define_node (node , tosa_graph , inputs , output , is_quant_node )
111
110
else :
112
111
# FP32 Add lowering
112
+ assert inputs [0 ].dtype == ts .DType .FP32
113
+ assert output .dtype == ts .DType .FP32
114
+
115
+ # MI lowering
113
116
tosa_graph .addOperator (
114
117
TosaOp .Op ().ADD ,
115
118
[inputs [0 ].name , inputs [1 ].name ],
0 commit comments