@@ -6,6 +6,7 @@
 
 import operator
 import re
+import time
 import warnings
 from collections import OrderedDict
 from typing import Any, Callable, Dict, FrozenSet, List, Optional, Tuple
@@ -740,17 +741,17 @@ def preprocess_binary(ctx_bin, compiler_specs):
     for k, v in type_map.items():
         dtype_map.setdefault(v, k)
 
-    qnn_in_order, executorch_in_order, executorch_out_order = [], [], []
+    qnn_in_order, executorch_in_order, executorch_out_order = None, None, None
     if custom_info is not None:
         # since some context binaries might fail to open on host
         # if they are compiled with special flags:
         # e.g. weight sharing
         # use custom information here instead
         inputs = build_tensor(custom_info["graph_inputs"], dtype_map)
         outputs = build_tensor(custom_info["graph_outputs"], dtype_map)
-        qnn_in_order = custom_info["qnn_in_order"]
-        executorch_in_order = custom_info["executorch_in_order"]
-        executorch_out_order = custom_info["executorch_out_order"]
+        qnn_in_order = custom_info.get("qnn_in_order", None)
+        executorch_in_order = custom_info.get("executorch_in_order", None)
+        executorch_out_order = custom_info.get("executorch_out_order", None)
         graph_name = custom_info["graph_name"]
     else:
         # get context-binary io tensor info through qnn manager
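Note: with the switch from direct indexing to dict.get, the three order keys become optional in custom_info; a caller that cannot recover IO ordering can omit them and the variables simply stay None. A minimal sketch of that contract (payload values are illustrative, not taken from this change):

    # Hypothetical caller-side payload: only graph_inputs/graph_outputs/graph_name
    # are required; the order keys may be absent.
    custom_info = {
        "graph_inputs": graph_inputs["forward"],
        "graph_outputs": graph_outputs["forward"],
        "graph_name": "forward",
    }
    qnn_in_order = custom_info.get("qnn_in_order", None)  # -> None, no KeyError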
@@ -800,7 +801,9 @@ def draw_graph(title, path, graph_module: torch.fx.GraphModule):
 
 
 def generate_multi_graph_program(
     compiler_specs: List[CompileSpec],
-    exported_programs: List[ExportedProgram] = None,
+    processed_bytes: List[bytes],
+    input_nodes_dict: List[torch.fx.Node] = None,
+    output_nodes_dict: List[torch.fx.Node] = None,
     backend_config: ExecutorchBackendConfig = None,
     constant_methods: Optional[Dict[str, Any]] = None,
 ) -> ExecutorchProgramManager:
@@ -813,10 +816,6 @@ def generate_multi_graph_program(
         executorch_in_order,
         executorch_out_order,
     ) = ({}, {}, {}, {}, {})
-
-    processed_bytes = [
-        prog.graph_module.lowered_module_0.processed_bytes for prog in exported_programs
-    ]
     qnn_mgr = PyQnnManagerAdaptor.QnnManager(
         generate_qnn_executorch_option(compiler_specs), processed_bytes
     )
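Note: generate_multi_graph_program no longer unpacks the compiled blobs from exported programs itself; the caller now passes processed_bytes directly. A hedged sketch of the caller side, mirroring the removed list comprehension (it assumes each exported program still exposes a single lowered_module_0, as the old inline code did):

    # Hypothetical caller-side collection of the context binaries.
    processed_bytes = [
        prog.graph_module.lowered_module_0.processed_bytes
        for prog in exported_programs
    ]
    exec_prog, bundle_progs = generate_multi_graph_program(
        compiler_specs=compiler_specs,
        processed_bytes=processed_bytes,
    )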
@@ -829,38 +828,36 @@ def generate_multi_graph_program(
         graph_outputs[graph_name] = qnn_mgr.GetGraphOutputs(graph_name)
 
     # We need to obtain the order of the IOs to correctly map QNN with nn.module
-    for i, graph_name in enumerate(graph_names):
-        # input
-        input_names = [
-            node.name
-            for node in exported_programs[i].graph_module.graph.nodes
-            if node.op == "placeholder"
-        ]
-        qnn_input_names = [wrapper.GetName() for wrapper in graph_inputs[graph_name]]
-        input_order_list = []
-        for input_name in input_names:
-            # e.g., input_0_tokens_0
-            pattern = rf"^input_(\d+)_({input_name})_(\d+)$"
-            for j in range(len(qnn_input_names)):
-                if re.match(pattern, qnn_input_names[j]):
-                    input_order_list.append(j)
-                    break
-        assert (
-            len(input_order_list) == len(input_names) == len(qnn_input_names)
-        ), "Order list length is different from names"
-        executorch_in_order[graph_name] = input_order_list
-        qnn_in_order[graph_name] = sorted(
-            range(len(input_order_list)), key=lambda k: input_order_list[k]
-        )
-
-        # output
-        get_item_list = [
-            node
-            for node in exported_programs[i].graph_module.graph.nodes
-            if node.op == "output"
-        ][0].args[0]
-        output_order_list = [item.args[1] for item in get_item_list]
-        executorch_out_order[graph_name] = output_order_list
+    for graph_name in graph_names:
+        if input_nodes_dict:
+            # input
+            input_names = [node.name for node in input_nodes_dict[graph_name]]
+            qnn_input_names = [
+                wrapper.GetName() for wrapper in graph_inputs[graph_name]
+            ]
+            # The inputs of an intermediate module that include call_function
+            # nodes cannot be reordered by node name
+            if len(input_names) == len(qnn_input_names):
+                input_order_list = []
+                for input_name in input_names:
+                    # e.g., input_0_tokens_0
+                    pattern = rf"^input_(\d+)_({input_name})_(\d+)$"
+                    for j in range(len(qnn_input_names)):
+                        if re.match(pattern, qnn_input_names[j]):
+                            input_order_list.append(j)
+                            break
+                assert len(input_order_list) == len(
+                    input_names
+                ), "Order list length is different from names"
+                executorch_in_order[graph_name] = input_order_list
+                qnn_in_order[graph_name] = sorted(
+                    range(len(input_order_list)), key=lambda k: input_order_list[k]
+                )
+        if output_nodes_dict:
+            # output
+            get_item_list = output_nodes_dict[graph_name][0].args[0]
+            output_order_list = [item.args[1] for item in get_item_list]
+            executorch_out_order[graph_name] = output_order_list
 
     qnn_mgr.Destroy()
 
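Note: executorch_in_order[graph_name][i] records which QNN input the i-th ExecuTorch placeholder maps to, and the sorted(range(...), key=...) call computes the inverse permutation. A standalone worked example of that relationship (values are illustrative):

    # If executorch input i corresponds to qnn input input_order_list[i]:
    input_order_list = [2, 0, 1]
    # then argsort inverts the permutation: qnn input j comes from
    # executorch input qnn_in_order[j].
    qnn_in_order = sorted(range(len(input_order_list)), key=lambda k: input_order_list[k])
    assert qnn_in_order == [1, 2, 0]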
@@ -869,15 +866,15 @@ def generate_multi_graph_program(
     bundle_progs = [
         from_context_binary(
             ctx_path=binary_info,
-            op_name=f"loader_{graph_name}",
+            op_name=f"loader_{graph_name}_{int(time.time())}",
             soc_model=compiler_options.soc_info.soc_model,
             custom_info={
                 "graph_inputs": graph_inputs[graph_name],
                 "graph_outputs": graph_outputs[graph_name],
                 "graph_name": graph_name,
-                "qnn_in_order": qnn_in_order[graph_name],
-                "executorch_in_order": executorch_in_order[graph_name],
-                "executorch_out_order": executorch_out_order[graph_name],
+                "qnn_in_order": qnn_in_order.get(graph_name, None),
+                "executorch_in_order": executorch_in_order.get(graph_name, None),
+                "executorch_out_order": executorch_out_order.get(graph_name, None),
             },
         )
         for graph_name in graph_names
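Note: the new int(time.time()) suffix appears intended to keep each loader's op_name unique across repeated calls within one process (and it only has one-second resolution, so two calls in the same second would still collide). Illustrative value only:

    op_name = f"loader_{graph_name}_{int(time.time())}"  # e.g. "loader_forward_1733813040"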
@@ -900,9 +897,101 @@ def generate_multi_graph_program(
             break
 
     edge_prog_mgr = edge_prog_mgr.to_backend(QnnPartitioner(compiler_specs))
-    return edge_prog_mgr.to_executorch(
+    exec_prog = edge_prog_mgr.to_executorch(
+        config=backend_config or ExecutorchBackendConfig()
+    )
+    return exec_prog, bundle_progs
+
+
+def generate_composite_llama_program(
+    graph_names: List[str],
+    sample_inputs_list: List[Tuple[Any]],
+    lower_module_dict: Dict[str, List[LoweredBackendModule]],
+    call_delegate_node_name_dict: Dict[str, List[str]],
+    call_delegate_inputs_dict: Dict[str, List[Tuple[str, int | None]]],
+    outputs_dict: Dict[str, List[Tuple[str, int]]],
+    backend_config: ExecutorchBackendConfig = None,
+    constant_methods: Optional[Dict[str, Any]] = None,
+) -> ExecutorchProgramManager:
+    class CompositeLlamaModule(torch.nn.Module):
+        def __init__(
+            self,
+            lower_module_list,
+            call_delegate_node_name_list,
+            call_delegate_inputs_list,
+            outputs_list,
+        ) -> None:
+            super().__init__()
+            self.lower_module_list = lower_module_list
+            self.call_delegate_node_name_list = call_delegate_node_name_list
+            self.call_delegate_inputs_list = call_delegate_inputs_list
+            self.outputs_list = outputs_list
+
+        def reorder(
+            self,
+            call_delegate_inputs: List[Tuple[str, int | None]],
+            module_inputs: dict[str, torch.Tensor],
+            all_ret: dict[str, torch.Tensor],
+        ) -> Tuple[torch.Tensor]:
+            ret = []
+            for name, index in call_delegate_inputs:
+                if index is not None:
+                    # Get tensor from previous results
+                    ret.append(all_ret[name][index])
+                else:
+                    # Get tensor from the inputs of module
+                    ret.append(module_inputs[name])
+            return tuple(ret)
+
+        def forward(
+            self,
+            tokens: torch.Tensor,
+            atten_mask: torch.Tensor,
+            input_pos: Optional[torch.Tensor] = None,
+            *args,
+        ) -> Tuple[torch.Tensor]:
+            all_ret = {}
+            module_input_dict = {
+                "tokens": tokens,
+                "atten_mask": atten_mask,
+                "input_pos": input_pos,
+            }
+            for num, arg in enumerate(args):
+                module_input_dict[f"args_{num}"] = arg
+            for lower_module, call_delegate_node_name, call_delegate_inputs in zip(
+                self.lower_module_list,
+                self.call_delegate_node_name_list,
+                self.call_delegate_inputs_list,
+            ):
+                inp = self.reorder(call_delegate_inputs, module_input_dict, all_ret)
+                ret = lower_module(*inp)
+                all_ret[call_delegate_node_name] = ret
+            llama_outputs = []
+            for output_src_name, index in self.outputs_list:
+                llama_outputs.append(all_ret[output_src_name][index])
+            return tuple(llama_outputs)
+
+    progs_dict = {}
+    for graph_name, sample_inputs in zip(graph_names, sample_inputs_list):
+        composite_llama_module = CompositeLlamaModule(
+            lower_module_dict[graph_name],
+            call_delegate_node_name_dict[graph_name],
+            call_delegate_inputs_dict[graph_name],
+            outputs_dict[graph_name],
+        )
+        prog = torch.export.export(composite_llama_module, sample_inputs)
+        progs_dict[graph_name] = prog
+    # leverage ExecutorchProgramManager for generating pte with multi-methods
+    edge_prog_mgr = to_edge(
+        progs_dict,
+        constant_methods=constant_methods,
+        # do not alter name for custom op
+        compile_config=EdgeCompileConfig(_check_ir_validity=False, _use_edge_ops=False),
+    )
+    exec_prog = edge_prog_mgr.to_executorch(
         config=backend_config or ExecutorchBackendConfig()
     )
+    return exec_prog
 
 
 def generate_htp_compiler_spec(
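Note: the (name, index) pairs consumed by CompositeLlamaModule.reorder use index=None for "take the module-level input with this name" and an integer index for "take that output of an earlier delegate call". A self-contained sketch of the convention with plain tensors (all names and values are illustrative):

    import torch

    module_inputs = {"tokens": torch.tensor([1, 2])}
    all_ret = {"lowered_module_0": (torch.tensor([3.0]), torch.tensor([4.0]))}
    call_delegate_inputs = [("tokens", None), ("lowered_module_0", 1)]

    inp = tuple(
        module_inputs[name] if index is None else all_ret[name][index]
        for name, index in call_delegate_inputs
    )
    assert inp[1].item() == 4.0  # second output of the earlier delegate call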