 import torch
 from executorch.backends.xnnpack.operators.node_visitor import (
     get_input_node,
-    InputTypeToIndex,
     NodeVisitor,
     register_node_visitor,
 )
+from executorch.backends.xnnpack.operators.quant_params import QuantParams
 from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
+    OutputMinMax,
     XNNFullyConnected,
     XNNGraph,
     XNode,
 )
+from executorch.backends.xnnpack.utils.utils import get_relu_fused_node
 
 from executorch.backends.xnnpack.utils.xnnpack_constants import XNN_INVALID_VALUE_ID
+from executorch.exir.dialects._ops import ops as exir_ops
 
 
 @register_node_visitor
@@ -36,30 +39,62 @@ def define_node(
         vals_to_ids: Dict[torch.fx.Node, int],
         debug_handle: int,
     ) -> None:
-        input_type_map = (
-            InputTypeToIndex(node_input=0, node_weight=1, node_bias=2)
-            if len(node.args) == 3
-            else InputTypeToIndex(node_input=0, node_weight=1)
-        )
-        self.define_nodes_tensor_inputs_outputs(
-            node, xnn_graph, vals_to_ids, input_type_map=input_type_map
-        )
-
-        # bias
-        bias_id = (
-            XNN_INVALID_VALUE_ID
-            if len(node.args) == 2
-            else vals_to_ids[get_input_node(node, input_type_map.node_bias)]
-        )
 
         # input
-        input_id = vals_to_ids[get_input_node(node, input_type_map.node_input)]
+        input_node = get_input_node(node, 0)
+        input_quant_params = QuantParams.from_inputs(input_node, self._exported_program)
+        self.define_tensor(
+            input_node,
+            xnn_graph,
+            vals_to_ids,
+            quant_params=input_quant_params,
+        )
+        input_id = vals_to_ids[input_node]
 
         # filter
-        filter_id = vals_to_ids[get_input_node(node, input_type_map.node_weight)]
+        weight_node = get_input_node(node, 1)
+        weight_quant_params = QuantParams.from_weights(
+            weight_node, self._exported_program
+        )
+        self.define_tensor(
+            weight_node,
+            xnn_graph,
+            vals_to_ids,
+            quant_params=weight_quant_params,
+        )
+        filter_id = vals_to_ids[weight_node]
+
+        # bias
+        if len(node.args) > 2:
+            bias_node = get_input_node(node, 2)
+            bias_quant_params = QuantParams.from_bias(
+                bias_node, weight_quant_params, input_quant_params
+            )
+            self.define_tensor(
+                get_input_node(node, 2),
+                xnn_graph,
+                vals_to_ids,
+                quant_params=bias_quant_params,
+            )
+            bias_id = vals_to_ids[bias_node]
+        else:
+            bias_id = XNN_INVALID_VALUE_ID
 
         # output
-        output_id = vals_to_ids[node]
+        output_node = get_relu_fused_node(node) or node
+        output_min_max = None
+        if output_node.target == exir_ops.edge.aten.relu.default:
+            output_node.meta["XNNPACK_FUSED"] = True
+            output_min_max = OutputMinMax(output_min=0, output_max="+inf")
+
+        output_quant_params = QuantParams.from_outputs(output_node)
+        self.define_tensor(
+            output_node,
+            xnn_graph,
+            vals_to_ids,
+            quant_params=output_quant_params,
+        )
+        output_id = vals_to_ids[output_node]
 
         ser_node = XNode(
             xnode_union=XNNFullyConnected(
@@ -70,5 +105,6 @@ def define_node(
                 flags=0,
             ),
             debug_handle=debug_handle,
+            output_min_max=output_min_max,
         )
         xnn_graph.xnodes.append(ser_node)
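
At a high level, this change defines the input, weight, and optional bias tensors individually so that per-tensor quantization parameters can be attached, and it folds a trailing aten.relu into the fully connected node by clamping its output range to [0, +inf). A minimal eager-mode sketch of the pattern being targeted (the LinearRelu module and its names are illustrative, not part of the PR):

import torch
import torch.nn.functional as F


class LinearRelu(torch.nn.Module):
    # Linear followed by ReLU; after lowering, the ReLU is absorbed into the
    # fully connected node as an output clamp (output_min=0, output_max=+inf).
    def __init__(self, in_features: int, out_features: int) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(in_features, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.relu(self.linear(x))


# Clamping the linear output at zero reproduces the ReLU exactly.
m = LinearRelu(4, 8)
x = torch.randn(2, 4)
assert torch.equal(m(x), torch.clamp(m.linear(x), min=0.0))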
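The fusion decision itself comes from get_relu_fused_node in the XNNPACK utils; conceptually it should return the trailing relu node when fusion is safe (for example, when the linear node's only consumer is a relu), and None otherwise. A hypothetical stand-alone version of that check, not the library's actual implementation, could look like:

from typing import Optional

import torch
from executorch.exir.dialects._ops import ops as exir_ops


def relu_user_or_none(node: torch.fx.Node) -> Optional[torch.fx.Node]:
    # Hypothetical stand-in for get_relu_fused_node: fuse only when the
    # linear node feeds exactly one consumer and that consumer is a relu.
    users = list(node.users)
    if len(users) == 1 and users[0].target == exir_ops.edge.aten.relu.default:
        return users[0]
    return None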