11
11
import executorch .backends .arm .tosa_utils as tutils
12
12
13
13
import serializer .tosa_serializer as ts
14
+ import torch
14
15
from executorch .backends .arm .operators .node_visitor import (
15
16
NodeVisitor ,
16
17
register_node_visitor ,
17
18
)
18
19
from executorch .backends .arm .tosa_mapping import TosaArg
20
+ from executorch .backends .arm .tosa_specification import TosaSpecification
19
21
from serializer .tosa_serializer import TosaOp
20
22
from torch .fx import Node
21
23
22
24
23
25
@register_node_visitor
24
- class AddVisitor (NodeVisitor ):
26
+ class AddVisitor_080_BI (NodeVisitor ):
25
27
target = "aten.add.Tensor"
26
28
29
+ tosa_specs = [
30
+ TosaSpecification .create_from_string ("TOSA-0.80.0+BI" ),
31
+ ]
32
+
27
33
def __init__ (self , * args ):
28
34
super ().__init__ (* args )
29
35
@@ -35,9 +41,22 @@ def define_node(
35
41
output : TosaArg ,
36
42
is_quant_node : bool ,
37
43
) -> None :
38
- if is_quant_node :
39
- input_nodes = tutils .get_two_inputs (node )
44
+ input_nodes = tutils .get_two_inputs (node )
45
+
46
+ if not is_quant_node and not all (
47
+ tensor .meta ["val" ].dtype in (torch .int8 , torch .int32 )
48
+ for tensor in input_nodes
49
+ ):
50
+ raise RuntimeError (
51
+ f"Unexpected non quantized { AddVisitor_080_BI .target } node."
52
+ )
40
53
54
+ needs_rescale = not (
55
+ all (tensor .meta ["val" ].dtype == torch .int32 for tensor in input_nodes )
56
+ and node .meta ["val" ].dtype == torch .int32
57
+ )
58
+
59
+ if needs_rescale :
41
60
# Rescale inputs to 32 bit
42
61
rescaled_inputs , scale = tqutils .rescale_nodes_to_int32 (
43
62
input_nodes , tosa_graph
@@ -48,20 +67,48 @@ def define_node(
48
67
rescaled_inputs [0 ].shape , rescaled_inputs [0 ].shape
49
68
)
50
69
add_output = tosa_graph .addIntermediate (broadcasted_shape , ts .DType .INT32 )
70
+ else :
71
+ add_output = output
72
+ rescaled_inputs = inputs
51
73
52
- # Do the INT32 Add
53
- tosa_graph .addOperator (
54
- TosaOp .Op ().ADD ,
55
- [
56
- rescaled_inputs [0 ].name ,
57
- rescaled_inputs [1 ].name ,
58
- ],
59
- [add_output .name ],
60
- None ,
61
- )
74
+ # Do the INT32 Add
75
+ tosa_graph .addOperator (
76
+ TosaOp .Op ().ADD ,
77
+ [
78
+ rescaled_inputs [0 ].name ,
79
+ rescaled_inputs [1 ].name ,
80
+ ],
81
+ [add_output .name ],
82
+ None ,
83
+ )
62
84
85
+ if needs_rescale :
63
86
# Scale output back to 8 bit
64
87
tqutils .rescale_node_back_to_int8 (node , add_output , scale , tosa_graph )
88
+
89
+
90
+ @register_node_visitor
91
+ class AddVisitor_080_MI (AddVisitor_080_BI ):
92
+ # inheriting 'target' from BI class
93
+
94
+ tosa_specs = [
95
+ TosaSpecification .create_from_string ("TOSA-0.80.0+MI" ),
96
+ ]
97
+
98
+ def __init__ (self , * args ):
99
+ super ().__init__ (* args )
100
+
101
+ def define_node (
102
+ self ,
103
+ node : Node ,
104
+ tosa_graph : ts .TosaSerializer ,
105
+ inputs : List [TosaArg ],
106
+ output : TosaArg ,
107
+ is_quant_node : bool ,
108
+ ) -> None :
109
+ if is_quant_node :
110
+ # Call the inherited define_node for handling integers
111
+ super ().define_node (node , tosa_graph , inputs , output , is_quant_node )
65
112
else :
66
113
# FP32 Add lowering
67
114
tosa_graph .addOperator (
0 commit comments