
Commit 9def8ff

enable sharding in hybrid mode

1 parent 97ab146 commit 9def8ff

File tree

3 files changed: +273 −63 lines

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 11 additions & 4 deletions
@@ -1655,8 +1655,12 @@ def test_qnn_backend_multi_graphs(self):
             to_backend(edge_prog.exported_program, QnnPartitioner(compiler_specs[i]))
             for i, edge_prog in enumerate(edge_progs)
         ]
-        prog_mgr = generate_multi_graph_program(
-            compiler_specs=compiler_specs[0], exported_programs=exported_programs
+        prog_mgr, _ = generate_multi_graph_program(
+            compiler_specs=compiler_specs[0],
+            processed_bytes=[
+                prog.graph_module.lowered_module_0.processed_bytes
+                for prog in exported_programs
+            ],
         )
         for index, module in enumerate(modules):
             self.verify_output(
@@ -2120,9 +2124,12 @@ def test_qnn_backend_multi_graphs(self):
             to_backend(edge_prog.exported_program, QnnPartitioner(compiler_specs[i]))
             for i, edge_prog in enumerate(edge_progs)
         ]
-        prog_mgr = generate_multi_graph_program(
+        prog_mgr, _ = generate_multi_graph_program(
             compiler_specs=compiler_specs[0],
-            exported_programs=exported_programs,
+            processed_bytes=[
+                prog.graph_module.lowered_module_0.processed_bytes
+                for prog in exported_programs
+            ],
         )
         for index, module in enumerate(modules):
             self.verify_output(
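
As the updated tests show, generate_multi_graph_program now takes the raw QNN context binaries and returns a tuple instead of a single program manager. A minimal sketch of the new call pattern, assuming exported_programs already holds programs lowered via to_backend and compiler_specs is a valid list of compile specs (the output file name is hypothetical):

    # Sketch only: hand the lowered context binaries to the multi-graph builder;
    # it now returns (ExecutorchProgramManager, bundle_progs).
    prog_mgr, bundle_progs = generate_multi_graph_program(
        compiler_specs=compiler_specs[0],
        processed_bytes=[
            prog.graph_module.lowered_module_0.processed_bytes
            for prog in exported_programs
        ],
    )
    with open("multi_graph.pte", "wb") as f:  # hypothetical output path
        prog_mgr.write_to_file(f)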

backends/qualcomm/utils/utils.py

Lines changed: 135 additions & 46 deletions
@@ -6,6 +6,7 @@
 
 import operator
 import re
+import time
 import warnings
 from collections import OrderedDict
 from typing import Any, Callable, Dict, FrozenSet, List, Optional, Tuple
@@ -740,17 +741,17 @@ def preprocess_binary(ctx_bin, compiler_specs):
     for k, v in type_map.items():
         dtype_map.setdefault(v, k)
 
-    qnn_in_order, executorch_in_order, executorch_out_order = [], [], []
+    qnn_in_order, executorch_in_order, executorch_out_order = None, None, None
     if custom_info is not None:
         # since some context binaries might fail to open on host
         # if they are compiled with special flags:
         # e.g. weight sharing
         # use custom information here instead
         inputs = build_tensor(custom_info["graph_inputs"], dtype_map)
         outputs = build_tensor(custom_info["graph_outputs"], dtype_map)
-        qnn_in_order = custom_info["qnn_in_order"]
-        executorch_in_order = custom_info["executorch_in_order"]
-        executorch_out_order = custom_info["executorch_out_order"]
+        qnn_in_order = custom_info.get("qnn_in_order", None)
+        executorch_in_order = custom_info.get("executorch_in_order", None)
+        executorch_out_order = custom_info.get("executorch_out_order", None)
         graph_name = custom_info["graph_name"]
     else:
         # get context-binary io tensor info through qnn manager
@@ -800,7 +801,9 @@ def draw_graph(title, path, graph_module: torch.fx.GraphModule):
 
 
 def generate_multi_graph_program(
     compiler_specs: List[CompileSpec],
-    exported_programs: List[ExportedProgram] = None,
+    processed_bytes: List[bytes],
+    input_nodes_dict: List[torch.fx.Node] = None,
+    output_nodes_dict: List[torch.fx.Node] = None,
     backend_config: ExecutorchBackendConfig = None,
     constant_methods: Optional[Dict[str, Any]] = None,
 ) -> ExecutorchProgramManager:
@@ -813,10 +816,6 @@ def generate_multi_graph_program(
         executorch_in_order,
         executorch_out_order,
     ) = ({}, {}, {}, {}, {})
-
-    processed_bytes = [
-        prog.graph_module.lowered_module_0.processed_bytes for prog in exported_programs
-    ]
     qnn_mgr = PyQnnManagerAdaptor.QnnManager(
         generate_qnn_executorch_option(compiler_specs), processed_bytes
     )
@@ -829,38 +828,36 @@ def generate_multi_graph_program(
         graph_outputs[graph_name] = qnn_mgr.GetGraphOutputs(graph_name)
 
     # We need to obtain the order of the IOs to correctly map QNN with nn.module
-    for i, graph_name in enumerate(graph_names):
-        # input
-        input_names = [
-            node.name
-            for node in exported_programs[i].graph_module.graph.nodes
-            if node.op == "placeholder"
-        ]
-        qnn_input_names = [wrapper.GetName() for wrapper in graph_inputs[graph_name]]
-        input_order_list = []
-        for input_name in input_names:
-            # e.g., input_0_tokens_0
-            pattern = rf"^input_(\d+)_({input_name})_(\d+)$"
-            for j in range(len(qnn_input_names)):
-                if re.match(pattern, qnn_input_names[j]):
-                    input_order_list.append(j)
-                    break
-        assert (
-            len(input_order_list) == len(input_names) == len(qnn_input_names)
-        ), "Order list length is different from names"
-        executorch_in_order[graph_name] = input_order_list
-        qnn_in_order[graph_name] = sorted(
-            range(len(input_order_list)), key=lambda k: input_order_list[k]
-        )
-
-        # output
-        get_item_list = [
-            node
-            for node in exported_programs[i].graph_module.graph.nodes
-            if node.op == "output"
-        ][0].args[0]
-        output_order_list = [item.args[1] for item in get_item_list]
-        executorch_out_order[graph_name] = output_order_list
+    for graph_name in graph_names:
+        if input_nodes_dict:
+            # input
+            input_names = [node.name for node in input_nodes_dict[graph_name]]
+            qnn_input_names = [
+                wrapper.GetName() for wrapper in graph_inputs[graph_name]
+            ]
+            # The inputs of an intermediate module (including call_function nodes)
+            # cannot be reordered by node name
+            if len(input_names) == len(qnn_input_names):
+                input_order_list = []
+                for input_name in input_names:
+                    # e.g., input_0_tokens_0
+                    pattern = rf"^input_(\d+)_({input_name})_(\d+)$"
+                    for j in range(len(qnn_input_names)):
+                        if re.match(pattern, qnn_input_names[j]):
+                            input_order_list.append(j)
+                            break
+                assert len(input_order_list) == len(
+                    input_names
+                ), "Order list length is different from names"
+                executorch_in_order[graph_name] = input_order_list
+                qnn_in_order[graph_name] = sorted(
+                    range(len(input_order_list)), key=lambda k: input_order_list[k]
+                )
+        if output_nodes_dict:
+            # output
+            get_item_list = output_nodes_dict[graph_name][0].args[0]
+            output_order_list = [item.args[1] for item in get_item_list]
+            executorch_out_order[graph_name] = output_order_list
 
     qnn_mgr.Destroy()
 
@@ -869,15 +866,15 @@ def generate_multi_graph_program(
     bundle_progs = [
         from_context_binary(
             ctx_path=binary_info,
-            op_name=f"loader_{graph_name}",
+            op_name=f"loader_{graph_name}_{int(time.time())}",
             soc_model=compiler_options.soc_info.soc_model,
             custom_info={
                 "graph_inputs": graph_inputs[graph_name],
                 "graph_outputs": graph_outputs[graph_name],
                 "graph_name": graph_name,
-                "qnn_in_order": qnn_in_order[graph_name],
-                "executorch_in_order": executorch_in_order[graph_name],
-                "executorch_out_order": executorch_out_order[graph_name],
+                "qnn_in_order": qnn_in_order.get(graph_name, None),
+                "executorch_in_order": executorch_in_order.get(graph_name, None),
+                "executorch_out_order": executorch_out_order.get(graph_name, None),
             },
         )
         for graph_name in graph_names
@@ -900,9 +897,101 @@ def generate_multi_graph_program(
            break
 
     edge_prog_mgr = edge_prog_mgr.to_backend(QnnPartitioner(compiler_specs))
-    return edge_prog_mgr.to_executorch(
+    exec_prog = edge_prog_mgr.to_executorch(
+        config=backend_config or ExecutorchBackendConfig()
+    )
+    return exec_prog, bundle_progs
+
+
+def generate_composite_llama_program(
+    graph_names: List[str],
+    sample_inputs_list: List[Tuple[Any]],
+    lower_module_dict: Dict[str, List[LoweredBackendModule]],
+    call_delegate_node_name_dict: Dict[str, List[str]],
+    call_delegate_inputs_dict: Dict[str, List[Tuple[str, int | None]]],
+    outputs_dict: Dict[str, List[Tuple[str, int]]],
+    backend_config: ExecutorchBackendConfig = None,
+    constant_methods: Optional[Dict[str, Any]] = None,
+) -> ExecutorchProgramManager:
+    class CompositeLlamaModule(torch.nn.Module):
+        def __init__(
+            self,
+            lower_module_list,
+            call_delegate_node_name_list,
+            call_delegate_inputs_list,
+            outputs_list,
+        ) -> None:
+            super().__init__()
+            self.lower_module_list = lower_module_list
+            self.call_delegate_node_name_list = call_delegate_node_name_list
+            self.call_delegate_inputs_list = call_delegate_inputs_list
+            self.outputs_list = outputs_list
+
+        def reorder(
+            self,
+            call_delegate_inputs: List[Tuple[str, int | None]],
+            module_inputs: dict[str, torch.Tensor],
+            all_ret: dict[str, torch.Tensor],
+        ) -> Tuple[torch.Tensor]:
+            ret = []
+            for name, index in call_delegate_inputs:
+                if index is not None:
+                    # Get tensor from previous results
+                    ret.append(all_ret[name][index])
+                else:
+                    # Get tensor from the inputs of module
+                    ret.append(module_inputs[name])
+            return tuple(ret)
+
+        def forward(
+            self,
+            tokens: torch.Tensor,
+            atten_mask: torch.Tensor,
+            input_pos: Optional[torch.Tensor] = None,
+            *args,
+        ) -> Tuple[torch.Tensor]:
+            all_ret = {}
+            module_input_dict = {
+                "tokens": tokens,
+                "atten_mask": atten_mask,
+                "input_pos": input_pos,
+            }
+            for num, arg in enumerate(args):
+                module_input_dict[f"args_{num}"] = arg
+            for lower_module, call_delegate_node_name, call_delegate_inputs in zip(
+                self.lower_module_list,
+                self.call_delegate_node_name_list,
+                self.call_delegate_inputs_list,
+            ):
+                inp = self.reorder(call_delegate_inputs, module_input_dict, all_ret)
+                ret = lower_module(*inp)
+                all_ret[call_delegate_node_name] = ret
+            llama_outputs = []
+            for output_src_name, index in self.outputs_list:
+                llama_outputs.append(all_ret[output_src_name][index])
+            return tuple(llama_outputs)
+
+    progs_dict = {}
+    for graph_name, sample_inputs in zip(graph_names, sample_inputs_list):
+        composite_llama_module = CompositeLlamaModule(
+            lower_module_dict[graph_name],
+            call_delegate_node_name_dict[graph_name],
+            call_delegate_inputs_dict[graph_name],
+            outputs_dict[graph_name],
+        )
+        prog = torch.export.export(composite_llama_module, sample_inputs)
+        progs_dict[graph_name] = prog
+    # leverage ExecutorchProgramManager for generating pte with multi-methods
+    edge_prog_mgr = to_edge(
+        progs_dict,
+        constant_methods=constant_methods,
+        # do not alter name for custom op
+        compile_config=EdgeCompileConfig(_check_ir_validity=False, _use_edge_ops=False),
+    )
+    exec_prog = edge_prog_mgr.to_executorch(
         config=backend_config or ExecutorchBackendConfig()
     )
+    return exec_prog
 
 
 def generate_htp_compiler_spec(
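
The new generate_composite_llama_program stitches per-shard lowered modules back into one nn.Module, routing each delegate call's inputs by (source_name, index) pairs: index=None reads from the composite module's own inputs, while an integer selects an output of a previously executed shard. A small self-contained sketch of that routing convention, mirroring CompositeLlamaModule.reorder (all names and shapes below are hypothetical, not taken from the commit):

    from typing import Dict, List, Optional, Tuple
    import torch

    def route_inputs(
        call_delegate_inputs: List[Tuple[str, Optional[int]]],
        module_inputs: Dict[str, torch.Tensor],
        all_ret: Dict[str, Tuple[torch.Tensor, ...]],
    ) -> Tuple[torch.Tensor, ...]:
        # Pick each tensor either from the module's inputs (index is None)
        # or from an earlier shard's output tuple (integer index).
        out = []
        for name, index in call_delegate_inputs:
            out.append(module_inputs[name] if index is None else all_ret[name][index])
        return tuple(out)

    # Hypothetical example: the second shard consumes the module's attention mask
    # plus the first output of the first shard.
    module_inputs = {"atten_mask": torch.zeros(1, 8)}
    all_ret = {"lowered_module_0": (torch.ones(1, 16),)}
    inp = route_inputs(
        [("atten_mask", None), ("lowered_module_0", 0)], module_inputs, all_ret
    )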
