11
11
from typing import Dict , final , List
12
12
13
13
import torch
14
-
15
- from executorch .backends .transforms import get_shape
16
14
from executorch .backends .xnnpack .operators .node_visitor import get_node_visitors
17
15
18
16
from executorch .backends .xnnpack .passes import XNNPACKPassManager
21
19
22
20
from executorch .backends .xnnpack .serialization .xnnpack_graph_schema import (
23
21
Buffer ,
24
- PerChannelQuant ,
25
- PerTensorQuant ,
26
- XNNDatatype ,
27
22
XNNGraph ,
28
- XNNQuantizedTensorValue ,
29
- XNNTensorValue ,
30
- XValue ,
31
23
)
32
24
from executorch .backends .xnnpack .serialization .xnnpack_graph_serialize import (
33
25
serialize_xnnpack_binary ,
34
26
)
35
27
from executorch .backends .xnnpack .utils .utils import is_param_node
36
28
29
+ from executorch .backends .xnnpack .utils .xnnpack_constants import (
30
+ XNN_VALUE_FLAG_EXTERNAL_INPUT ,
31
+ XNN_VALUE_FLAG_EXTERNAL_OUTPUT ,
32
+ )
33
+
37
34
from executorch .exir .backend .backend_details import (
38
35
BackendDetails ,
39
36
CompileSpec ,
42
39
from executorch .exir .verification .verifier import EXIREdgeDialectVerifier
43
40
from torch .export .exported_program import ExportedProgram
44
41
45
# Flag bits and sentinel ids used while building the serialized XNNPACK graph.
# NOTE(review): presumably these mirror the XNN_* definitions in XNNPACK's
# public header -- confirm before changing any value.
XNN_VALUE_FLAG_NON_EXTERNAL = 0
XNN_VALUE_FLAG_EXTERNAL_INPUT = 1
XNN_VALUE_FLAG_EXTERNAL_OUTPUT = 2
XNN_FLAG_TRANSPOSE_WEIGHTS = 1
XNN_INVALID_VALUE_ID = 0xFFFFFFFF  # 2**32 - 1

# Lookup table: torch dtype -> XNNDatatype written into the serialized tensor.
# Dtypes that are not listed here raise KeyError on lookup.
XNN_TYPE_MAP = {
    torch.float32: XNNDatatype.xnn_datatype_fp32,
    torch.uint8: XNNDatatype.xnn_datatype_quint8,
    torch.int8: XNNDatatype.xnn_datatype_qint8,
    torch.int32: XNNDatatype.xnn_datatype_qint32,
}

DEFAULT_DEBUG_HANDLE = 0xFFFF  # 65535

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
@@ -65,97 +51,6 @@ class ExternalMeta:
65
51
io_type : int
66
52
67
53
68
def node_to_xvalue(
    node: torch.fx.Node,
    constant_buffer_idx: int,
    external_id: int,
    flags: int,
    id_out: int,
    dq_datatype=XNNDatatype.xnn_datatype_invalid,
) -> XValue:
    """Build an ``XValue`` that wraps a plain ``XNNTensorValue`` for ``node``.

    The tensor description is derived from ``node.meta["val"]``: its dtype is
    translated through ``XNN_TYPE_MAP`` and its ``dim()`` gives the rank. The
    dims themselves come from ``get_shape(node)``.

    Args:
        node: FX node whose ``meta["val"]`` describes the tensor.
        constant_buffer_idx: Forwarded unchanged to ``XNNTensorValue``.
        external_id: Forwarded unchanged to ``XNNTensorValue``.
        flags: Forwarded unchanged to ``XNNTensorValue``.
        id_out: Forwarded unchanged to ``XNNTensorValue``.
        dq_datatype: Forwarded unchanged to ``XNNTensorValue``; defaults to
            ``XNNDatatype.xnn_datatype_invalid``.

    Returns:
        An ``XValue`` whose ``xvalue_union`` is the new ``XNNTensorValue``.

    Raises:
        KeyError: If the dtype of ``node.meta["val"]`` is not a key of
            ``XNN_TYPE_MAP``.
    """
    meta_val = node.meta["val"]
    tensor_value = XNNTensorValue(
        datatype=XNN_TYPE_MAP[meta_val.dtype],
        num_dims=meta_val.dim(),
        dims=get_shape(node),
        constant_buffer_idx=constant_buffer_idx,
        external_id=external_id,
        flags=flags,
        id_out=id_out,
        dq_datatype=dq_datatype,
    )
    return XValue(xvalue_union=tensor_value)
90
-
91
-
92
def node_to_per_tensor_quantized_xvalue(
    node: torch.fx.Node,
    dtype: torch.dtype,
    constant_buffer_idx: int,
    external_id: int,
    flags: int,
    id_out: int,
    scale: float,
    zero_point: int,
) -> XValue:
    """Build an ``XValue`` for ``node`` carrying per-tensor quantization params.

    Unlike ``node_to_xvalue``, the serialized datatype is looked up from the
    explicit ``dtype`` argument rather than from ``node.meta["val"].dtype``.
    Rank and dims still come from ``node.meta["val"].dim()`` and
    ``get_shape(node)``.

    Args:
        node: FX node whose ``meta["val"]`` supplies the rank.
        dtype: torch dtype translated through ``XNN_TYPE_MAP``.
        constant_buffer_idx: Forwarded unchanged to ``XNNTensorValue``.
        external_id: Forwarded unchanged to ``XNNTensorValue``.
        flags: Forwarded unchanged to ``XNNTensorValue``.
        id_out: Forwarded unchanged to ``XNNTensorValue``.
        scale: Scale stored in the ``PerTensorQuant`` params.
        zero_point: Zero point stored in the ``PerTensorQuant`` params.

    Returns:
        An ``XValue`` whose ``xvalue_union`` is an ``XNNQuantizedTensorValue``
        pairing the tensor description with its ``PerTensorQuant`` params.

    Raises:
        KeyError: If ``dtype`` is not a key of ``XNN_TYPE_MAP``.
    """
    meta_val = node.meta["val"]
    tensor_value = XNNTensorValue(
        datatype=XNN_TYPE_MAP[dtype],
        num_dims=meta_val.dim(),
        dims=get_shape(node),
        constant_buffer_idx=constant_buffer_idx,
        external_id=external_id,
        flags=flags,
        id_out=id_out,
        # The dq datatype is always "invalid" for this kind of value.
        dq_datatype=XNNDatatype.xnn_datatype_invalid,
    )
    quant_params = PerTensorQuant(scale=scale, zero_point=zero_point)
    quantized_value = XNNQuantizedTensorValue(
        tensor_value=tensor_value,
        quant_params=quant_params,
    )
    return XValue(xvalue_union=quantized_value)
122
-
123
-
124
def node_to_per_channel_quantized_xvalue(
    node: torch.fx.Node,
    dtype: torch.dtype,
    constant_buffer_idx: int,
    external_id: int,
    flags: int,
    id_out: int,
    channel_dim: int,
    scale: torch.Tensor,
) -> XValue:
    """Build an ``XValue`` for ``node`` carrying per-channel quantization params.

    Only ``torch.int8`` is accepted. The serialized datatype is hard-coded to
    ``XNNDatatype.xnn_datatype_qcint8`` instead of being looked up in
    ``XNN_TYPE_MAP``. Rank and dims come from ``node.meta["val"].dim()`` and
    ``get_shape(node)``.

    Args:
        node: FX node whose ``meta["val"]`` supplies the rank.
        dtype: Must be ``torch.int8``.
        constant_buffer_idx: Forwarded unchanged to ``XNNTensorValue``.
        external_id: Forwarded unchanged to ``XNNTensorValue``.
        flags: Forwarded unchanged to ``XNNTensorValue``.
        id_out: Forwarded unchanged to ``XNNTensorValue``.
        channel_dim: Channel dimension stored in the ``PerChannelQuant`` params.
        scale: Tensor of scales; stored as ``scale.tolist()``.

    Returns:
        An ``XValue`` whose ``xvalue_union`` is an ``XNNQuantizedTensorValue``
        pairing the tensor description with its ``PerChannelQuant`` params.

    Raises:
        AssertionError: If ``dtype`` is not ``torch.int8``.
    """
    node_val = node.meta["val"]
    # Use the documented ``torch.int8`` name; the previous ``torch.torch.int8``
    # only resolved because the torch module exposes itself as an attribute.
    assert (
        dtype == torch.int8
    ), f"per-channel quantization only supports torch.int8, got {dtype}"
    node_xvalue = XNNTensorValue(
        datatype=XNNDatatype.xnn_datatype_qcint8,  # HACK: XNN_TYPE_MAP[dtype],
        num_dims=node_val.dim(),
        dims=get_shape(node),
        constant_buffer_idx=constant_buffer_idx,
        external_id=external_id,
        flags=flags,
        id_out=id_out,
        dq_datatype=XNNDatatype.xnn_datatype_invalid,  # always invalid
    )

    per_channel_quantized_params = PerChannelQuant(
        scale=scale.tolist(), channel_dim=channel_dim
    )
    return XValue(
        xvalue_union=XNNQuantizedTensorValue(
            tensor_value=node_xvalue,
            quant_params=per_channel_quantized_params,
        )
    )
157
-
158
-
159
54
def generate_node_to_external_map (
160
55
exported_program : ExportedProgram ,
161
56
edge_graph_module : torch .fx .GraphModule ,
0 commit comments