Skip to content

Commit c50f1ba

Browse files
cccclaifacebook-github-bot
authored andcommitted
Expose method name as part of backend init context (#6622)
Summary: Provide the method name to the backend so it can load the corresponding method accordingly. The most immediate need is that the QNN context binary can include two methods, one for prefill and one for decode. Since we don't allow a backend to access multiple methods at the moment, we do it in a hacky way as follows. ## AOT: ``` class LLama_transformer(): def prefill() def decode() ``` Then we will have two custom ops from two to_backend ops, and both will have the same two-graph context binary ``` QAT (prefill) -> to_backend(...) => prefill.qcir flatbuffers QAT (decode) -> to_backend(...) => decode.qcir flatbuffers => graph prefill( custom_op_prefill() -> context_binary (two graphs) ) graph decode() custom_op_decode() -> context_binary (two graphs) ) ``` Since the two context binaries from these two custom ops will be exactly the same, they can be deduplicated during emit via these two lines https://github.com/pytorch/executorch/blob/d4a9ca01eb5bb786ecbfbcd8302253eb7797e8bb/exir/emit/_emitter.py#L136 and here https://github.com/pytorch/executorch/blob/d4a9ca01eb5bb786ecbfbcd8302253eb7797e8bb/exir/emit/_emitter.py#L1065-L1066 ``` .pte instructions [ "prefill" [instructions: call_delegate(prefill_input)] "decode": [instructions: call_delegate(decode_input)] "delegate_payload:: Dict[bytes, index]) ] ``` ## Runtime After we expose the method name via this change, the backend can access the method name and load the same method as the top-level method ``` Result<DelegateHandle*> QNNBackend::init( BackendInitContext& context, FreeableBuffer* processed, ArrayRef<CompileSpec> compile_specs) { const char* method_name = context.get_method_name() // for example, "prefill" handle = qnn_backend.load(method_name) return handle } ``` This is to unblock sharing weights between prefill and decode when using the HTP backend. Differential Revision: D65386597
1 parent 85ec022 commit c50f1ba

File tree

4 files changed

+29
-3
lines changed

4 files changed

+29
-3
lines changed

runtime/backend/backend_init_context.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ namespace runtime {
1818
*/
1919
class BackendInitContext final {
2020
public:
21-
explicit BackendInitContext(MemoryAllocator* runtime_allocator)
22-
: runtime_allocator_(runtime_allocator) {}
21+
explicit BackendInitContext(MemoryAllocator* runtime_allocator, const char* method_name)
22+
: runtime_allocator_(runtime_allocator), method_name_(method_name) {}
2323

2424
/** Get the runtime allocator passed from Method. It's the same runtime
2525
* executor used by the standard executor runtime and the life span is the
@@ -28,9 +28,14 @@ class BackendInitContext final {
2828
MemoryAllocator* get_runtime_allocator() {
2929
return runtime_allocator_;
3030
}
31+
32+
/** Get the method name that was passed to this context's constructor.
 * The returned pointer is not owned by the context; it is stored and
 * returned as-is, and is nullptr if no name was supplied.
 * NOTE(review): the string's lifetime is determined by the caller that
 * constructed the context — confirm before caching the pointer.
 */
const char* get_method_name() {
33+
return method_name_;
34+
}
3135

3236
private:
3337
MemoryAllocator* runtime_allocator_ = nullptr;
38+
const char* method_name_ = nullptr;
3439
};
3540

3641
} // namespace runtime

runtime/executor/method.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -598,6 +598,7 @@ Error Method::init(executorch_flatbuffer::ExecutionPlan* s_plan) {
598598
init_state_ =
599599
InitializationState::InitializationFailed; // Until proven otherwise
600600
serialization_plan_ = s_plan;
601+
// Take the C string directly from the flatbuffer string so the pointer refers
// to the serialized plan data. Going through str() would build a temporary
// std::string that is destroyed at the end of this statement, leaving
// method_name_ dangling.
method_name_ = s_plan->name()->c_str();
601602
auto method_allocator = memory_manager_->method_allocator();
602603

603604
{
@@ -626,7 +627,7 @@ Error Method::init(executorch_flatbuffer::ExecutionPlan* s_plan) {
626627

627628
for (size_t i = 0; i < n_delegate; ++i) {
628629
const auto& delegate = *delegates->Get(i);
629-
BackendInitContext backend_init_context(method_allocator);
630+
BackendInitContext backend_init_context(method_allocator, method_name_);
630631
Error err = BackendDelegate::Init(
631632
delegate, program_, backend_init_context, &delegates_[i]);
632633
if (err != Error::Ok) {

runtime/executor/method.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,7 @@ class Method final {
328328

329329
size_t n_chains_;
330330
Chain* chains_;
331+
const char* method_name_;
331332

332333
InitializationState init_state_;
333334

runtime/executor/test/backend_integration_test.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -528,6 +528,25 @@ TEST_P(BackendIntegrationTest, SegmentInfoIsPassedIntoDataLoader) {
528528
EXPECT_EQ(backend_load_was_called, using_segments());
529529
}
530530

531+
TEST_P(BackendIntegrationTest, GetMethodNameSuccess) {
532+
Result<FileDataLoader> loader = FileDataLoader::from(program_path());
533+
ASSERT_EQ(loader.error(), Error::Ok);
534+
const void* processed_data = nullptr;
535+
StubBackend::singleton().install_init(
536+
[&](FreeableBuffer* processed,
537+
ET_UNUSED ArrayRef<CompileSpec> compile_specs,
538+
ET_UNUSED BackendInitContext backend_init_context)
539+
-> Result<DelegateHandle*> {
540+
auto method_name = backend_init_context.get_method_name();
541+
// Compare string contents: EXPECT_EQ on a const char* compares addresses,
// not the characters they point to.
EXPECT_STREQ(method_name, "forward");
542+
processed_data = processed->data();
543+
return nullptr;
544+
});
545+
Result<Program> program = Program::load(&loader.get());
546+
ASSERT_EQ(program.error(), Error::Ok);
547+
548+
}
549+
531550
// TODO: Add more tests for the runtime-to-backend interface. E.g.:
532551
// - Errors during init() or execute() result in runtime init/execution failures
533552
// - Correct values are passed to init()/execute()

0 commit comments

Comments
 (0)