5
5
6
6
import numpy as np
7
7
import serializer .tosa_serializer as ts
8
- import torch
8
+ import torch . fx
9
9
from executorch .backends .arm .tosa_mapping import TosaArg
10
10
from executorch .backends .arm .tosa_quant_utils import (
11
11
get_quant_arg_dtype ,
@@ -130,6 +130,21 @@ def process_inputs_to_buffers(
130
130
)
131
131
132
132
133
+ def process_inputs_to_lifted_tensor_constants (
134
+ node : torch .fx .Node ,
135
+ tosa_graph : ts .TosaSerializer ,
136
+ edge_program : ExportedProgram ,
137
+ ):
138
+ arg = TosaArg (node )
139
+ tensor_name = edge_program .graph_signature .inputs_to_lifted_tensor_constants [
140
+ arg .name
141
+ ]
142
+ tensor = edge_program .tensor_constants [tensor_name ]
143
+ tensor_data = tensor .detach ().numpy ()
144
+
145
+ tosa_graph .addConst (tensor_data .shape , arg .dtype , tensor_data , name = arg .name )
146
+
147
+
133
148
def process_placeholder (
134
149
node : torch .fx .Node ,
135
150
tosa_graph : ts .TosaSerializer ,
@@ -145,5 +160,11 @@ def process_placeholder(
145
160
process_inputs_to_parameters (node , tosa_graph , edge_program )
146
161
elif node .name in edge_program .graph_signature .inputs_to_buffers :
147
162
process_inputs_to_buffers (node , tosa_graph , edge_program )
163
+ elif node .name in edge_program .graph_signature .inputs_to_lifted_tensor_constants :
164
+ process_inputs_to_lifted_tensor_constants (node , tosa_graph , edge_program )
165
+ elif node .name in edge_program .graph_signature .inputs_to_lifted_custom_objs :
166
+ raise NotImplementedError (
167
+ "Placeholder is of type 'lifted custom object' which is not supported."
168
+ )
148
169
else :
149
- raise RuntimeError (f"Unknown placeholder { node .name } " )
170
+ raise RuntimeError (f"Placeholder ' { node .name } ' is of unknown type. " )
0 commit comments