14
14
from executorch .backends .arm .operators .node_visitor import NodeVisitor
15
15
from executorch .backends .arm .tosa_mapping import TosaArg
16
16
from executorch .backends .arm .tosa_specification import TosaSpecification
17
- from executorch .backends .arm .tosa_utils import getNodeArgs , tosa_shape
17
+ from executorch .backends .arm .tosa_utils import (
18
+ get_node_debug_info ,
19
+ getNodeArgs ,
20
+ tosa_shape ,
21
+ )
18
22
from torch .export .exported_program import ExportedProgram
19
23
20
24
@@ -28,8 +32,13 @@ def process_call_function(
28
32
inputs = getNodeArgs (node )
29
33
30
34
# Convert output (this node itself)
31
- output = TosaArg (node )
32
-
35
+ try :
36
+ output = TosaArg (node )
37
+ except ValueError as e :
38
+ raise ValueError (
39
+ f"Failed processing call_function:\n { get_node_debug_info (node )} "
40
+ "Is the original torch function supported?"
41
+ ) from e
33
42
tosa_graph .currRegion .currBasicBlock .addTensor (
34
43
output .name , tosa_shape (output .shape , output .dim_order ), output .dtype
35
44
)
@@ -61,15 +70,21 @@ def process_inputs(
61
70
f"Arm backend only supports contiguous memory format for inputs. "
62
71
f"Expected dim_order: { tuple (range (meta .dim ()))} , but got: { meta .dim_order ()} for node { node .name } "
63
72
)
64
- inputs = [TosaArg (node )]
65
- input_shape = inputs [0 ].shape
66
- input_dim_order = inputs [0 ].dim_order
73
+ try :
74
+ tosa_arg = TosaArg (node )
75
+ except ValueError as e :
76
+ raise ValueError (
77
+ f"Failed processing input placeholder:\n { get_node_debug_info (node )} "
78
+ "Is the original torch function supported?"
79
+ ) from e
80
+ input_shape = tosa_arg .shape
81
+ input_dim_order = tosa_arg .dim_order
67
82
tensor = ts .TosaSerializerTensor (
68
- inputs [ 0 ] .name ,
83
+ tosa_arg .name ,
69
84
tosa_shape (input_shape , input_dim_order ),
70
- inputs [ 0 ] .dtype ,
85
+ tosa_arg .dtype ,
71
86
data = None ,
72
- placeholderFilename = inputs [ 0 ] .name + ".npy" ,
87
+ placeholderFilename = tosa_arg .name + ".npy" ,
73
88
)
74
89
tosa_graph .addInputTensor (tensor )
75
90
@@ -81,20 +96,26 @@ def process_inputs_to_parameters(
81
96
tosa_spec : TosaSpecification ,
82
97
):
83
98
"""Serialize bias and non-quantized weights"""
84
- inputs = [TosaArg (node )]
85
- parameter_name = edge_program .graph_signature .inputs_to_parameters [node .name ]
99
+ try :
100
+ tosa_arg = TosaArg (node )
101
+ except ValueError as e :
102
+ raise ValueError (
103
+ f"Failed processing parameter placeholder:\n { get_node_debug_info (node )} "
104
+ "Is the original torch function supported?"
105
+ ) from e
106
+ parameter_name = edge_program .graph_signature .inputs_to_parameters [tosa_arg .name ]
86
107
parameter_data = edge_program .state_dict [parameter_name ]
87
108
88
109
assert isinstance (parameter_data , torch .Tensor ), "Expect Attr to be tensor"
89
110
parameter_values = parameter_data .detach ().numpy ()
90
111
91
- if inputs [ 0 ] .dtype == torch .float32 :
112
+ if tosa_arg .dtype == torch .float32 :
92
113
assert tosa_spec .support_float (), f"{ tosa_spec } doesn't support float"
93
114
94
- parameter_values = np .transpose (parameter_values , inputs [ 0 ] .dim_order )
115
+ parameter_values = np .transpose (parameter_values , tosa_arg .dim_order )
95
116
96
117
tosa_graph .addConst (
97
- parameter_values .shape , inputs [ 0 ] .dtype , parameter_values , name = node .name
118
+ parameter_values .shape , tosa_arg .dtype , parameter_values , name = tosa_arg .name
98
119
)
99
120
100
121
@@ -104,7 +125,13 @@ def process_inputs_to_buffers(
104
125
edge_program : ExportedProgram ,
105
126
):
106
127
"""Serialize quantized weights"""
107
- inputs = [TosaArg (node )]
128
+ try :
129
+ tosa_arg = TosaArg (node )
130
+ except ValueError as e :
131
+ raise ValueError (
132
+ f"Failed processing buffer placeholder:\n { get_node_debug_info (node )} "
133
+ "Is the original torch function supported?"
134
+ ) from e
108
135
buffer_name = edge_program .graph_signature .inputs_to_buffers [node .name ]
109
136
buffer_data = edge_program .state_dict [buffer_name ]
110
137
@@ -114,10 +141,10 @@ def process_inputs_to_buffers(
114
141
# TODO: fragile code for temporary fix
115
142
# the mean and var tensors are also stored here but they have shape (1, )
116
143
# we only transpose weights here
117
- buffer_values = np .transpose (buffer_values , inputs [ 0 ] .dim_order )
144
+ buffer_values = np .transpose (buffer_values , tosa_arg .dim_order )
118
145
119
146
tosa_graph .addConst (
120
- buffer_values .shape , inputs [ 0 ] .dtype , buffer_values , name = node .name
147
+ buffer_values .shape , tosa_arg .dtype , buffer_values , name = node .name
121
148
)
122
149
123
150
@@ -126,14 +153,22 @@ def process_inputs_to_lifted_tensor_constants(
126
153
tosa_graph : ts .TosaSerializer ,
127
154
edge_program : ExportedProgram ,
128
155
):
129
- arg = TosaArg (node )
156
+ try :
157
+ tosa_arg = TosaArg (node )
158
+ except ValueError as e :
159
+ raise ValueError (
160
+ f"Failed processing lifted tensor constant placeholder:\n { get_node_debug_info (node )} "
161
+ "Is the original torch function supported?"
162
+ ) from e
130
163
tensor_name = edge_program .graph_signature .inputs_to_lifted_tensor_constants [
131
- arg .name
164
+ tosa_arg .name
132
165
]
133
166
tensor = edge_program .tensor_constants [tensor_name ]
134
167
tensor_data = tensor .detach ().numpy ()
135
168
136
- tosa_graph .addConst (tensor_data .shape , arg .dtype , tensor_data , name = arg .name )
169
+ tosa_graph .addConst (
170
+ tensor_data .shape , tosa_arg .dtype , tensor_data , name = tosa_arg .name
171
+ )
137
172
138
173
139
174
def process_placeholder (
0 commit comments