Skip to content

Commit b73fb1e

Browse files
pytorchbotSheng Feng Wu
andauthored
Qualcomm AI Engine Direct - Refine max spill fill buffer setting (#6041)
Qualcomm AI Engine Direct - Refine max spill fill buffer setting (#5989) Summary: - Get required spillFillBufferSize from context binary and set to compiler_spec - Quantize embedding op in qnn. - If enable multi-contexts, maxSpillFillBuffer could not set to zero. Pull Request resolved: #5989 Reviewed By: kirklandsign Differential Revision: D64056107 Pulled By: cccclai fbshipit-source-id: 9f9846e6ac7b4a27d734d2812ac3bbad32fb194f (cherry picked from commit 01fcdf4) Co-authored-by: Sheng Feng Wu <[email protected]>
1 parent 0a3002f commit b73fb1e

18 files changed

+163
-39
lines changed

backends/qualcomm/aot/python/PyQnnManagerAdaptor.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ PYBIND11_MODULE(PyQnnManagerAdaptor, m) {
3535
.def("IsTensorDump", &PyQnnManager::IsTensorDump)
3636
.def("AllocateTensor", &PyQnnManager::AllocateTensor)
3737
.def("GetGraphInputs", &PyQnnManager::GetGraphInputs)
38-
.def("GetGraphOutputs", &PyQnnManager::GetGraphOutputs);
38+
.def("GetGraphOutputs", &PyQnnManager::GetGraphOutputs)
39+
.def("GetSpillFillBufferSize", &PyQnnManager::GetSpillFillBufferSize);
3940
}
4041
} // namespace qnn
4142
} // namespace executor

backends/qualcomm/aot/python/PyQnnManagerAdaptor.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,10 @@ class PyQnnManager {
177177
return ret;
178178
}
179179

180+
uint64_t GetSpillFillBufferSize() {
181+
return qnn_manager_->GetSpillFillBufferSize();
182+
}
183+
180184
private:
181185
// Store the bytes object instead of a raw pointer so that this module will
182186
// keep the bytes alive.

backends/qualcomm/runtime/QnnManager.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,10 @@ Error QnnManager::Init() {
283283
qnn_loaded_backend_, logger_.get(), qnn_context_blob_, options_);
284284
ET_CHECK_OR_RETURN_ERROR(
285285
backend_params_ptr_ != nullptr, Internal, "Failed to load Qnn backend.")
286+
ET_CHECK_OR_RETURN_ERROR(
287+
backend_params_ptr_->qnn_backend_cache_ptr_->Configure() == Error::Ok,
288+
Internal,
289+
"Fail to configure Qnn backend cache");
286290
ET_CHECK_OR_RETURN_ERROR(
287291
backend_params_ptr_->qnn_backend_ptr_->Configure() == Error::Ok,
288292
Internal,

backends/qualcomm/runtime/QnnManager.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,12 @@ class QnnManager {
7070
// Pre-register custom memory handle from the SharedBuffer before execution
7171
Error PreRegisterMem();
7272

73+
uint64_t GetSpillFillBufferSize() {
74+
auto* htp_backend_cache_ptr = static_cast<HtpBackendCache*>(
75+
backend_params_ptr_->qnn_backend_cache_ptr_.get());
76+
return htp_backend_cache_ptr->GetSpillFillBufferSize();
77+
}
78+
7379
std::vector<std::shared_ptr<TensorWrapper>> GetGraphInputs() {
7480
return input_tensors_;
7581
}

backends/qualcomm/runtime/backends/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,9 @@ target_sources(
7777
target_sources(
7878
qnn_backend_cache
7979
PUBLIC ${CMAKE_CURRENT_LIST_DIR}/QnnBackendCache.h
80+
${CMAKE_CURRENT_LIST_DIR}/htpbackend/HtpBackendCache.h
8081
PRIVATE ${CMAKE_CURRENT_LIST_DIR}/QnnBackendCache.cpp
82+
${CMAKE_CURRENT_LIST_DIR}/htpbackend/HtpBackendCache.cpp
8183
)
8284

8385
# qnn_graph
@@ -130,6 +132,7 @@ set(qnn_header_basenames
130132
HTP/QnnHtpPerfInfrastructure.h
131133
HTP/QnnHtpProfile.h
132134
HTP/QnnHtpProperty.h
135+
HTP/QnnHtpSystemContext.h
133136
QnnInterface.h
134137
QnnLog.h
135138
QnnMem.h

backends/qualcomm/runtime/backends/QnnBackendCache.cpp

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,20 @@ Error QnnBackendCache::GetQnnGraphInfoFromBinary() {
2828

2929
if (error != QNN_SUCCESS) {
3030
QNN_EXECUTORCH_LOG_WARN(
31-
"Failed to interpret QNN Context "
31+
"Failed to interpret QNN context "
3232
"binary. Error code %d. "
3333
"Try verifying binary with online-prepare format.",
3434
QNN_GET_ERROR_CODE(error));
3535
return Error::Internal;
3636
}
3737

38+
Error status = RetrieveBackendBinaryInfo(binaryinfo);
39+
if (status == Error::Internal) {
40+
QNN_EXECUTORCH_LOG_ERROR(
41+
"Failed to retrieve backend binary info from QNN context binary.");
42+
return Error::Internal;
43+
}
44+
3845
if (binaryinfo->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_1) {
3946
num_graphs = binaryinfo->contextBinaryInfoV1.numGraphs;
4047
graph = binaryinfo->contextBinaryInfoV1.graphs;
@@ -81,20 +88,18 @@ Error QnnBackendCache::GetQnnGraphInfoFromBinary() {
8188
return Error::Ok;
8289
}
8390

84-
QnnBackendCache::QnnBackendCache(
85-
const QnnExecuTorchContextBinary& qnn_context_blob)
86-
: qnn_context_blob_(qnn_context_blob) {
91+
Error QnnBackendCache::Configure() {
8792
if (qnn_context_blob_.buffer == nullptr) {
8893
state_ = SERIALIZE;
8994
QNN_EXECUTORCH_LOG_INFO("Caching: Caching is in SAVE MODE.");
90-
return;
95+
return Error::Ok;
9196
}
9297

9398
if (qnn_sys_impl_.Load() != Error::Ok) {
9499
QNN_EXECUTORCH_LOG_ERROR(
95100
"Failed to Load QnnSystem "
96101
"APIs. Caching mechanism is being disabled.");
97-
return;
102+
return Error::Internal;
98103
}
99104

100105
Qnn_ErrorHandle_t error = QNN_SUCCESS;
@@ -109,7 +114,7 @@ QnnBackendCache::QnnBackendCache(
109114
"Failed to create Qnn "
110115
"SystemContext. Caching mechanism will be disabled. Error code %d",
111116
QNN_GET_ERROR_CODE(error));
112-
return;
117+
return Error::Internal;
113118
}
114119

115120
// DO DESERIALIZE
@@ -125,16 +130,16 @@ QnnBackendCache::QnnBackendCache(
125130

126131
if (qcir::VerifyGraphBuffer(verifier)) {
127132
state_ = ONLINE_PREPARE;
128-
return;
133+
return Error::Ok;
129134
}
130135

131136
QNN_EXECUTORCH_LOG_ERROR(
132137
"Failed to parse QNN Graph Info. The cache "
133138
"might be broken. Please consider to re-generate the "
134139
"cache.");
135140
InvalidateCache();
136-
return;
137141
}
142+
return Error::Ok;
138143
}
139144

140145
QnnBackendCache::~QnnBackendCache() {

backends/qualcomm/runtime/backends/QnnBackendCache.h

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,9 @@ class QnnBackendCache {
2323
DESERIALIZE = 2,
2424
ONLINE_PREPARE = 3,
2525
};
26-
explicit QnnBackendCache(const QnnExecuTorchContextBinary& qnn_context_blob);
27-
28-
~QnnBackendCache();
26+
explicit QnnBackendCache(const QnnExecuTorchContextBinary& qnn_context_blob)
27+
: qnn_context_blob_(qnn_context_blob) {}
28+
virtual ~QnnBackendCache();
2929
QnnBackendCache(const QnnBackendCache&) = delete;
3030
QnnBackendCache(QnnBackendCache&&) = delete;
3131
QnnBackendCache& operator=(const QnnBackendCache&) = delete;
@@ -51,6 +51,14 @@ class QnnBackendCache {
5151
return graph_name_;
5252
}
5353

54+
Error Configure();
55+
56+
protected:
57+
virtual Error RetrieveBackendBinaryInfo(
58+
__ET_UNUSED const QnnSystemContext_BinaryInfo_t* binaryinfo) {
59+
return Error::Ok;
60+
}
61+
5462
private:
5563
Error GetQnnGraphInfoFromBinary();
5664

backends/qualcomm/runtime/backends/QnnBackendFactory.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,14 @@ std::unique_ptr<BackendConfigParameters> QnnBackendFactory::Create(
5656
backend_params->qnn_device_ptr_ = std::make_unique<HtpDevice>(
5757
implementation, logger, options->soc_info(), htp_options);
5858

59+
backend_params->qnn_backend_cache_ptr_ =
60+
std::make_unique<HtpBackendCache>(qnn_context_blob);
61+
5962
backend_params->qnn_context_ptr_ = std::make_unique<HtpContext>(
6063
implementation,
6164
backend_params->qnn_backend_ptr_.get(),
6265
backend_params->qnn_device_ptr_.get(),
63-
qnn_context_blob,
66+
backend_params->qnn_backend_cache_ptr_.get(),
6467
htp_options);
6568

6669
backend_params->qnn_graph_ptr_ = std::make_unique<HtpGraph>(

backends/qualcomm/runtime/backends/QnnBackendFactory.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#pragma once
99

1010
#include <executorch/backends/qualcomm/runtime/QnnExecuTorch.h>
11+
#include <executorch/backends/qualcomm/runtime/backends/QnnBackendCache.h>
1112
#include <executorch/backends/qualcomm/runtime/backends/QnnBackendCommon.h>
1213
#include <executorch/backends/qualcomm/runtime/backends/QnnContextCommon.h>
1314
#include <executorch/backends/qualcomm/runtime/backends/QnnDeviceCommon.h>
@@ -16,6 +17,7 @@
1617
#include <executorch/backends/qualcomm/runtime/backends/QnnLogger.h>
1718
#include <executorch/backends/qualcomm/runtime/backends/QnnMemManager.h>
1819
#include <executorch/backends/qualcomm/runtime/backends/htpbackend/HtpBackend.h>
20+
#include <executorch/backends/qualcomm/runtime/backends/htpbackend/HtpBackendCache.h>
1921
#include <executorch/backends/qualcomm/runtime/backends/htpbackend/HtpContext.h>
2022
#include <executorch/backends/qualcomm/runtime/backends/htpbackend/HtpDevice.h>
2123
#include <executorch/backends/qualcomm/runtime/backends/htpbackend/HtpGraph.h>
@@ -35,6 +37,7 @@ typedef struct BackendConfigParameters {
3537
std::unique_ptr<QnnDevice> qnn_device_ptr_;
3638
std::unique_ptr<QnnGraph> qnn_graph_ptr_;
3739
std::unique_ptr<QnnMemManager> qnn_mem_manager_ptr_;
40+
std::unique_ptr<QnnBackendCache> qnn_backend_cache_ptr_;
3841

3942
// Default ctor
4043
BackendConfigParameters()
@@ -43,10 +46,12 @@ typedef struct BackendConfigParameters {
4346
qnn_context_ptr_(nullptr),
4447
qnn_device_ptr_(nullptr),
4548
qnn_graph_ptr_(nullptr),
46-
qnn_mem_manager_ptr_(nullptr) {}
49+
qnn_mem_manager_ptr_(nullptr),
50+
qnn_backend_cache_ptr_(nullptr) {}
4751
// Default dtor
4852
~BackendConfigParameters() {
4953
qnn_graph_ptr_.reset();
54+
qnn_backend_cache_ptr_.reset();
5055
qnn_mem_manager_ptr_.reset();
5156
qnn_context_ptr_.reset();
5257
qnn_device_ptr_.reset();

backends/qualcomm/runtime/backends/QnnContextCommon.h

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,12 @@ class QnnContext {
2222
const QnnImplementation& implementation,
2323
QnnBackend* backend,
2424
QnnDevice* device,
25-
const QnnExecuTorchContextBinary& qnn_context_blob)
25+
QnnBackendCache* cache)
2626
: handle_(nullptr),
2727
implementation_(implementation),
2828
backend_(backend),
29-
device_(device) {
30-
cache_ = std::make_unique<QnnBackendCache>(qnn_context_blob);
31-
}
29+
device_(device),
30+
cache_(cache) {}
3231

3332
virtual ~QnnContext();
3433
Error Configure();
@@ -67,7 +66,7 @@ class QnnContext {
6766
const QnnImplementation& implementation_;
6867
QnnBackend* backend_;
6968
QnnDevice* device_;
70-
std::unique_ptr<QnnBackendCache> cache_;
69+
QnnBackendCache* cache_;
7170
std::vector<char> binary_buffer_;
7271
};
7372
} // namespace qnn
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
/*
2+
* Copyright (c) Qualcomm Innovation Center, Inc.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
#include <executorch/backends/qualcomm/runtime/backends/htpbackend/HtpBackendCache.h>
9+
#include "HTP/QnnHtpSystemContext.h"
10+
11+
namespace torch {
12+
namespace executor {
13+
namespace qnn {
14+
Error HtpBackendCache::RetrieveBackendBinaryInfo(
15+
const QnnSystemContext_BinaryInfo_t* binaryinfo) {
16+
QnnHtpSystemContext_HwBlobInfo_t* htp_hwblobinfo = nullptr;
17+
18+
if (binaryinfo->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_1) {
19+
htp_hwblobinfo = static_cast<QnnHtpSystemContext_HwBlobInfo_t*>(
20+
binaryinfo->contextBinaryInfoV1.hwInfoBlob);
21+
} else if (binaryinfo->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_2) {
22+
htp_hwblobinfo = static_cast<QnnHtpSystemContext_HwBlobInfo_t*>(
23+
binaryinfo->contextBinaryInfoV2.hwInfoBlob);
24+
} else {
25+
QNN_EXECUTORCH_LOG_WARN(
26+
"Unknown QNN BinaryInfo version %d.", binaryinfo->version);
27+
return Error::Internal;
28+
}
29+
30+
if (htp_hwblobinfo == nullptr) {
31+
QNN_EXECUTORCH_LOG_WARN(
32+
"Htp hardware blob information is not found in binary information.");
33+
return Error::Ok;
34+
}
35+
36+
if (htp_hwblobinfo->version ==
37+
QNN_SYSTEM_CONTEXT_HTP_HW_INFO_BLOB_VERSION_V1) {
38+
spill_fill_buf_ =
39+
(*htp_hwblobinfo).contextBinaryHwInfoBlobV1_t.spillFillBufferSize;
40+
} else {
41+
QNN_EXECUTORCH_LOG_WARN(
42+
"Unknown QNN Htp hw blob info version %d.", htp_hwblobinfo->version);
43+
return Error::Internal;
44+
}
45+
46+
return Error::Ok;
47+
}
48+
49+
} // namespace qnn
50+
} // namespace executor
51+
} // namespace torch
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
/*
2+
* Copyright (c) Qualcomm Innovation Center, Inc.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
#pragma once
9+
#include <executorch/backends/qualcomm/runtime/backends/QnnBackendCache.h>
10+
11+
namespace torch {
12+
namespace executor {
13+
namespace qnn {
14+
class HtpBackendCache : public QnnBackendCache {
15+
public:
16+
explicit HtpBackendCache(const QnnExecuTorchContextBinary& qnn_context_blob)
17+
: QnnBackendCache(qnn_context_blob), spill_fill_buf_(0) {}
18+
~HtpBackendCache() override = default;
19+
20+
uint64_t GetSpillFillBufferSize() {
21+
return spill_fill_buf_;
22+
}
23+
24+
protected:
25+
Error RetrieveBackendBinaryInfo(
26+
const QnnSystemContext_BinaryInfo_t* binaryinfo) override;
27+
28+
private:
29+
uint64_t spill_fill_buf_;
30+
};
31+
} // namespace qnn
32+
} // namespace executor
33+
} // namespace torch

backends/qualcomm/runtime/backends/htpbackend/HtpContext.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,9 @@ class HtpContext : public QnnContext {
2222
const QnnImplementation& implementation,
2323
QnnBackend* backend,
2424
QnnDevice* device,
25-
const QnnExecuTorchContextBinary& qnn_context_blob,
25+
QnnBackendCache* cache,
2626
const QnnExecuTorchHtpBackendOptions* htp_options)
27-
: QnnContext(implementation, backend, device, qnn_context_blob) {
27+
: QnnContext(implementation, backend, device, cache) {
2828
htp_context_custom_config_ =
2929
std::make_unique<HtpContextCustomConfig>(this, htp_options);
3030
}

backends/qualcomm/runtime/backends/htpbackend/aarch64/HtpContextCustomConfig.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@ HtpContextCustomConfig::CreateContextCustomConfig() {
1919
QnnHtpContext_CustomConfig_t* p_custom_config = nullptr;
2020
const HtpContext* htp_ctx = static_cast<const HtpContext*>(context_);
2121

22-
if (htp_options_->use_multi_contexts()) {
22+
if (htp_options_->use_multi_contexts() &&
23+
htp_options_->max_sf_buf_size() != 0) {
2324
p_custom_config = AllocContextCustomConfig();
2425
p_custom_config->option =
2526
QNN_HTP_CONTEXT_CONFIG_OPTION_REGISTER_MULTI_CONTEXTS;

backends/qualcomm/utils/utils.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -208,16 +208,28 @@ def process_exported_program(prog):
208208
== QnnExecuTorchBackendType.kHtpBackend
209209
and options.backend_options.htp_options.use_multi_contexts
210210
):
211-
max_sf_buf_size = max(max_sf_buf_size, len(m.processed_bytes))
211+
qnn_mgr = PyQnnManagerAdaptor.QnnManager(
212+
m.compile_specs[0].value, m.processed_bytes
213+
)
214+
assert qnn_mgr.Init().value == 0, "failed to load context binary"
215+
max_sf_buf_size = max(
216+
max_sf_buf_size, qnn_mgr.GetSpillFillBufferSize()
217+
)
212218
module_map[m] = options
219+
qnn_mgr.Destroy()
213220
return max_sf_buf_size, module_map
214221

215222
def process_lowered_module(module):
223+
qnn_mgr = PyQnnManagerAdaptor.QnnManager(
224+
module.compile_specs[0].value, module.processed_bytes
225+
)
226+
assert qnn_mgr.Init().value == 0, "failed to load context binary"
216227
spill_fill_size = (
217-
len(module.processed_bytes)
228+
qnn_mgr.GetSpillFillBufferSize()
218229
if custom_buffer_size is None
219230
else custom_buffer_size
220231
)
232+
qnn_mgr.Destroy()
221233
return spill_fill_size, {
222234
module: convert_to_option(module.compile_specs[0].value)
223235
}

0 commit comments

Comments
 (0)