Commit c7c0211
Qualcomm AI Engine Direct - enable multiple graphs in single pte
Summary:
- support multiple graphs in a single QNN context at runtime
- helper function in the AOT flow for generating a multi-method PTE
- enable the weight-sharing mechanism on HTP
- test cases
1 parent 11d1742 commit c7c0211

31 files changed: +784 −318 lines
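Taken together, the changes below let the AOT flow lower each method to qcir, merge the per-method payloads into one QNN context, and emit a single context binary that the runtime addresses by graph name. A minimal sketch of that flow through the Python bindings added in this commit (variable names and buffer contents are illustrative; the option bytes and qcir payloads come from the usual compiler-spec and preprocess steps):

import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager

# qnn_options: serialized QnnExecuTorchOptions bytes (assumed input)
# qcir_forward, qcir_prefill: per-method qcir payloads (assumed inputs)
manager = PyQnnManager.QnnManager(qnn_options, [qcir_forward, qcir_prefill])
manager.Init()

# the zero-argument Compile() overload merges all graphs and compiles them
# into one QNN context binary, ready to be wrapped into a single .pte
context_binary = manager.Compile()
print(manager.GetGraphNames())  # e.g. ['forward', 'prefill']
manager.Destroy()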

backends/qualcomm/aot/ir/qcir.fbs

Lines changed: 6 additions & 1 deletion
@@ -94,8 +94,13 @@ table Operator {
 }
 
 table Graph {
+  name: string;
   nodes: [Operator];
   tensors: [Tensor];
 }
 
-root_type Graph;
+table Context {
+  graphs: [Graph];
+}
+
+root_type Context;
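With this change the buffer root is no longer a single Graph: a Context owns a list of named Graphs, which is what allows one qcir payload, and ultimately one context binary, to carry every method of a multi-method program.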

backends/qualcomm/aot/ir/qcir_utils.cpp

Lines changed: 2 additions & 6 deletions
@@ -161,9 +161,7 @@ flatbuffers::Offset<qcir::QuantizeParam> ToQuantizeParam(
       }
     } break;
     default:
-      QNN_EXECUTORCH_LOG_WARN(
-          "QNN_QUANTIZATION_ENCODING_UNDEFINED detected: %s",
-          QNN_VER_PTR(tensor)->name);
+      // encodings are not required if lowering with floating point precision
       break;
   }
   return CreateQuantizeParamDirect(
@@ -229,9 +227,7 @@ Qnn_QuantizeParams_t ToQuantizeParam(const tensor_type& tensor) {
           const_cast<int32_t*>(param->offsets()->data());
     } break;
     default:
-      QNN_EXECUTORCH_LOG_WARN(
-          "qcir::QuantizeType::UNDEFINED detected: %s",
-          tensor->name()->c_str());
+      // encodings are not required if lowering with floating point precision
       break;
   }
   return p;

backends/qualcomm/aot/python/PyQnnManagerAdaptor.cpp

Lines changed: 8 additions & 1 deletion
@@ -30,15 +30,22 @@ PYBIND11_MODULE(PyQnnManagerAdaptor, m) {
   py::class_<PyQnnManager, std::shared_ptr<PyQnnManager>>(m, "QnnManager")
       .def(py::init<const py::bytes&>())
       .def(py::init<const py::bytes&, const py::bytes&>())
+      .def(py::init<const py::bytes&, const py::list&>())
       .def("Init", &PyQnnManager::Init)
       .def("IsNodeSupportedByBackend", &PyQnnManager::IsNodeSupportedByBackend)
-      .def("Compile", &PyQnnManager::Compile)
+      .def("Compile", py::overload_cast<>(&PyQnnManager::Compile))
+      .def(
+          "Compile",
+          py::overload_cast<
+              const std::string&,
+              std::vector<std::shared_ptr<OpWrapper>>&>(&PyQnnManager::Compile))
       .def("Destroy", &PyQnnManager::Destroy)
       .def("IsAvailable", &PyQnnManager::IsAvailable)
       .def("IsTensorDump", &PyQnnManager::IsTensorDump)
       .def("AllocateTensor", &PyQnnManager::AllocateTensor)
       .def("GetGraphInputs", &PyQnnManager::GetGraphInputs)
       .def("GetGraphOutputs", &PyQnnManager::GetGraphOutputs)
+      .def("GetGraphNames", &PyQnnManager::GetGraphNames)
       .def("GetSpillFillBufferSize", &PyQnnManager::GetSpillFillBufferSize);
 }
 } // namespace qnn
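Compile is now overloaded on the Python side as well. A short sketch of driving the named overload, mirroring the call site in qnn_preprocess.py below (op_wrappers stands in for the list the node visitors build):

graph_name = qnn_manager.GetGraphNames()[0]
qnn_context_binary = qnn_manager.Compile(graph_name, op_wrappers)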

backends/qualcomm/aot/python/PyQnnManagerAdaptor.h

Lines changed: 108 additions & 13 deletions
@@ -47,20 +47,100 @@ class PyQnnManager {
     qnn_manager_ = std::make_shared<QnnManager>(
         qnn_executorch_options, qnn_executorch_context_binary_);
   }
+
+  // used for loading multiple graphs in qcir
+  explicit PyQnnManager(const py::bytes& buffer, const py::list& qcirs)
+      : qnn_executorch_option_ptr_(buffer) {
+    auto qnn_executorch_options = GetQnnExecuTorchOptions(
+        qnn_executorch_option_ptr_.cast<std::string_view>().data());
+
+    // merge multiple qcirs into one context with multiple graphs
+    std::vector<flatbuffers::Offset<qcir::Graph>> graphs;
+    for (size_t i = 0; i < qcirs.size(); ++i) {
+      py::buffer_info info(py::buffer(qcirs[i].cast<py::bytes>()).request());
+      flatbuffers::Verifier verifier(
+          static_cast<const uint8_t* const>(info.ptr),
+          info.size * info.itemsize);
+
+      if (!qcir::VerifyContextBuffer(verifier)) {
+        QNN_EXECUTORCH_LOG_ERROR("Fail to verify qcir format");
+        return;
+      }
+      auto context = qcir::GetContext(info.ptr);
+      for (const auto& graph : *context->graphs()) {
+        std::vector<flatbuffers::Offset<qcir::Tensor>> tensors;
+        for (const auto tensor : *graph->tensors()) {
+          // flatbuffers::Offset<Tensor> ToTensor(
+          //     QnnTensor ToTensor(
+          //         flatbuffers::Vector<::flatbuffers::Offset<qcir::Tensor>> tensor),
+          //     flatbuffers::FlatBufferBuilder* builder);
+          tensors.emplace_back(ToTensor(ToTensor(tensor), &builder_));
+        }
+        std::vector<flatbuffers::Offset<qcir::Operator>> nodes;
+        for (const auto& node : *graph->nodes()) {
+          int32_t* inputs_ptr = const_cast<int32_t*>(node->inputs()->data());
+          int32_t* outputs_ptr = const_cast<int32_t*>(node->outputs()->data());
+          int32_t* params_ptr = const_cast<int32_t*>(node->params()->data());
+          std::vector<int32_t> inputs(
+              inputs_ptr, inputs_ptr + node->inputs()->size());
+          std::vector<int32_t> outputs(
+              outputs_ptr, outputs_ptr + node->outputs()->size());
+          std::vector<int32_t> params(
+              params_ptr, params_ptr + node->params()->size());
+          nodes.emplace_back(qcir::CreateOperatorDirect(
+              builder_,
+              node->name()->str().c_str(),
+              node->package_name()->str().c_str(),
+              node->type_name()->str().c_str(),
+              &inputs,
+              &outputs,
+              &params));
+        }
+        graphs.emplace_back(qcir::CreateGraphDirect(
+            builder_, graph->name()->str().c_str(), &nodes, &tensors));
+      }
+    }
+    auto context = qcir::CreateContextDirect(builder_, &graphs);
+    builder_.Finish(context);
+    qnn_executorch_context_binary_.buffer = builder_.GetBufferPointer();
+    qnn_executorch_context_binary_.nbytes = builder_.GetSize();
+    qnn_manager_ = std::make_shared<QnnManager>(
+        qnn_executorch_options, qnn_executorch_context_binary_);
+  }
 
   executorch::runtime::Error Init() {
     return qnn_manager_->Init();
   }
+
   bool IsNodeSupportedByBackend(
       std::vector<std::shared_ptr<OpWrapper>>& op_wrappers) {
     return qnn_manager_->IsNodeSupportedByBackend(op_wrappers);
   }
+
+  // this method is specific for compiling multi-graphs
+  py::array_t<char> Compile() {
+    if (qnn_manager_->CompileQcir() != Error::Ok) {
+      QNN_EXECUTORCH_LOG_ERROR("Fail to compile qcir");
+      return py::array_t<char>(0);
+    }
+
+    // generate context binary if compilation succeeded
+    QnnExecuTorchContextBinary context_binary;
+    qnn_manager_->GetContextBinary(context_binary);
+    // allocate py::array (to pass the result of the C++ function to Python)
+    auto result = py::array_t<char>(context_binary.nbytes);
+    auto result_buffer = result.request();
+    char* result_ptr = (char*)result_buffer.ptr;
+    std::memcpy(result_ptr, context_binary.buffer, context_binary.nbytes);
+    return result;
+  }
+
   py::array_t<char> Compile(
+      const std::string& graph_name,
       std::vector<std::shared_ptr<OpWrapper>>& op_wrappers) {
     QnnExecuTorchContextBinary context_binary;
     flatbuffers::FlatBufferBuilder builder;
 
-    if (qnn_manager_->IsOnlinePrepare()) {
+    if (qnn_manager_->IsOnlinePrepare() || qnn_manager_->IsMultipleGraphs()) {
       std::vector<flatbuffers::Offset<qcir::Tensor>> tensors;
       std::unordered_map<void*, int> tensor_map;
 
@@ -126,14 +206,19 @@ class PyQnnManager {
             &outputs,
             &params));
       }
-      auto graph = qcir::CreateGraphDirect(builder, &operators, &tensors);
-      builder.Finish(graph);
+      auto graph = qcir::CreateGraphDirect(
+          builder, graph_name.c_str(), &operators, &tensors);
+      std::vector<flatbuffers::Offset<qcir::Graph>> graphs({graph});
+      auto context = qcir::CreateContextDirect(builder, &graphs);
+      builder.Finish(context);
       context_binary.buffer = builder.GetBufferPointer();
       context_binary.nbytes = builder.GetSize();
-    } else if (
-        qnn_manager_->Compile(op_wrappers, context_binary) !=
-        executorch::runtime::Error::Ok) {
-      return py::array_t<char>(0);
+    } else {
+      if (qnn_manager_->Compile(graph_name, op_wrappers) !=
+          executorch::runtime::Error::Ok) {
+        return py::array_t<char>(0);
+      }
+      qnn_manager_->GetContextBinary(context_binary);
     }
 
     // allocate py::array (to pass the result of the C++ function to
@@ -144,6 +229,7 @@ class PyQnnManager {
     std::memcpy(result_ptr, context_binary.buffer, context_binary.nbytes);
     return result;
   }
+
   void Destroy() {
     return qnn_manager_->Destroy();
   }
@@ -156,28 +242,36 @@ class PyQnnManager {
     return qnn_manager_->IsTensorDump();
   }
 
-  executorch::runtime::Error AllocateTensor() {
-    return qnn_manager_->AllocateTensor();
+  executorch::runtime::Error AllocateTensor(const std::string& graph_name) {
+    return qnn_manager_->AllocateTensor(graph_name);
   }
 
-  py::list GetGraphInputs() {
+  py::list GetGraphInputs(const std::string& graph_name) {
     py::list ret;
     for (const std::shared_ptr<TensorWrapper>& input :
-         qnn_manager_->GetGraphInputs()) {
+         qnn_manager_->GetGraphInputs(graph_name)) {
      ret.append(PyQnnTensorWrapper(input));
     }
     return ret;
   }
 
-  py::list GetGraphOutputs() {
+  py::list GetGraphOutputs(const std::string& graph_name) {
     py::list ret;
     for (const std::shared_ptr<TensorWrapper>& output :
-         qnn_manager_->GetGraphOutputs()) {
+         qnn_manager_->GetGraphOutputs(graph_name)) {
       ret.append(PyQnnTensorWrapper(output));
     }
     return ret;
   }
 
+  py::list GetGraphNames() {
+    py::list ret;
+    for (const std::string& graph_name : qnn_manager_->GetGraphNames()) {
+      ret.append(graph_name);
+    }
+    return ret;
+  }
+
   uint64_t GetSpillFillBufferSize() {
     return qnn_manager_->GetSpillFillBufferSize();
   }
@@ -188,6 +282,7 @@ class PyQnnManager {
   const py::bytes qnn_executorch_option_ptr_;
   QnnExecuTorchContextBinary qnn_executorch_context_binary_;
   std::shared_ptr<QnnManager> qnn_manager_;
+  flatbuffers::FlatBufferBuilder builder_;
 };
 } // namespace qnn
 } // namespace backends
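Note the division of labor here: when IsOnlinePrepare() or IsMultipleGraphs() holds, the named Compile() overload serializes the graph into a single-graph qcir Context instead of producing a QNN binary; the merged Context assembled by the new constructor is later compiled in one shot by CompileQcir(), and GetContextBinary() hands back the result. builder_ became a class member because qnn_executorch_context_binary_ points into its buffer, which must outlive the constructor.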

backends/qualcomm/partition/qnn_partitioner.py

Lines changed: 7 additions & 2 deletions
@@ -10,9 +10,9 @@
 import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManager
 import torch
 from executorch.backends.qualcomm.builders import node_visitor
+from executorch.backends.qualcomm.builders.qnn_constants import OpContextLoader
 from executorch.backends.qualcomm.qnn_preprocess import QnnBackend
 from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER
-from executorch.backends.qualcomm.utils.utils import generate_qnn_executorch_option
 
 from executorch.exir.backend.backend_details import CompileSpec
 from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import (
@@ -32,6 +32,7 @@
     not_supported_operator,
     to_be_implemented_operator,
 )
+from .utils import generate_qnn_executorch_option
 
 
 class QnnOperatorSupport(OperatorSupportBase):
@@ -63,7 +64,11 @@ def is_node_supported(self, _, node: torch.fx.Node) -> bool:
             )
             return False
 
-        if node.target in allow_list_operator:
+        if (
+            node.target in allow_list_operator
+            # bypass if custom op appears
+            or OpContextLoader.namespace == node.target.namespace
+        ):
             return True
 
         if (
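Nodes in the OpContextLoader namespace wrap pre-built QNN context binaries, so the partitioner accepts them outright instead of running the per-operator support checks, which would not recognize these custom targets.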

backends/qualcomm/partition/utils.py

Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# All rights reserved
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import List
+
+from executorch.backends.qualcomm.utils.constants import QCOM_QNN_COMPILE_SPEC
+
+from executorch.exir.backend.compile_spec_schema import CompileSpec
+
+
+def generate_qnn_executorch_option(
+    compiler_specs: List[CompileSpec],
+) -> bytes:
+    for compiler_spec in compiler_specs:
+        if compiler_spec.key == QCOM_QNN_COMPILE_SPEC:
+            qnn_compile_spec_buffer = compiler_spec.value
+        else:
+            raise ValueError(f"unknown compiler spec key value: {compiler_spec.key}")
+    return qnn_compile_spec_buffer
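A hedged usage sketch of the relocated helper (serialized_options stands in for the bytes produced by the backend's compiler-spec generation; note the helper raises on any spec whose key it does not recognize):

from executorch.backends.qualcomm.partition.utils import generate_qnn_executorch_option
from executorch.backends.qualcomm.utils.constants import QCOM_QNN_COMPILE_SPEC
from executorch.exir.backend.compile_spec_schema import CompileSpec

# serialized_options: assumed QnnExecuTorchOptions bytes from an earlier step
compiler_specs = [CompileSpec(QCOM_QNN_COMPILE_SPEC, serialized_options)]
option_bytes = generate_qnn_executorch_option(compiler_specs)  # -> bytes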

backends/qualcomm/qnn_preprocess.py

Lines changed: 4 additions & 3 deletions
@@ -19,7 +19,7 @@
 from executorch.backends.qualcomm._passes.layout_transform import LayoutTransform
 from executorch.backends.qualcomm.builders.node_visitor import get_node_visitors
 from executorch.backends.qualcomm.builders.qnn_constants import OpContextLoader
-from executorch.backends.qualcomm.utils.utils import generate_qnn_executorch_option
+from executorch.backends.qualcomm.partition.utils import generate_qnn_executorch_option
 from executorch.exir.backend.backend_details import (
     BackendDetails,
     CompileSpec,
@@ -83,7 +83,7 @@ def preprocess(
                 )
                 try:
                     context_loader_target = eval(
-                        f"torch.ops.{OpContextLoader.namespace}.{node.name}.default",
+                        f"torch.ops.{OpContextLoader.namespace}.{node.target.__name__}",
                         globals().update(torch.__dict__),
                     )
                     assert node.target == context_loader_target, err_msg
@@ -104,7 +104,8 @@ def preprocess(
         else:
             raise RuntimeError(f"{node.op} is not supported in Qnn")
     qnn_context_binary = qnn_manager.Compile(
-        [py_op_wrapper.GetOpWrapper() for py_op_wrapper in py_op_wrapper_list]
+        qnn_manager.GetGraphNames()[0],
+        [py_op_wrapper.GetOpWrapper() for py_op_wrapper in py_op_wrapper_list],
     )
     assert len(qnn_context_binary) != 0, "Failed to generate Qnn context binary."
     qnn_manager.Destroy()
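Since preprocess always lowers exactly one method at a time, it compiles under the first (and only) graph name the manager registered; merging several such results into a multi-graph context happens afterwards, in the AOT helper this commit adds.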
