Skip to content

Commit 9c0d9ab

Browse files
cccclaifacebook-github-bot
authored andcommitted
Expose method name as part of backend init context (#6622)
Summary: Provide the method name to the backend so that it can load the corresponding method. The most immediate need is that the QNN context binary can include two methods, one for prefill and one for decode. Since we don't currently allow a backend to access multiple methods, we do it in a hacky way, as follows. ## AOT: ``` class LLama_transformer(): def prefill() def decode() ``` Then we will have two custom ops from two to_backend ops, and each will carry a context binary containing two graphs ``` QAT (prefill) -> to_backend(...) => prefill.qcir flatbuffers QAT (decode) -> to_backend(...) => decode.qcir flatbuffers => graph prefill( custom_op_prefill() -> context_binary (two graphs) ) graph decode() custom_op_decode() -> context_binary (two graphs) ) ``` Since the two context binaries from these two custom ops will be exactly the same, they can be deduplicated during emit via these two lines https://github.com/pytorch/executorch/blob/d4a9ca01eb5bb786ecbfbcd8302253eb7797e8bb/exir/emit/_emitter.py#L136 and here https://github.com/pytorch/executorch/blob/d4a9ca01eb5bb786ecbfbcd8302253eb7797e8bb/exir/emit/_emitter.py#L1065-L1066 ``` .pte instructions [ "prefill" [instructions: call_delegate(prefill_input)] "decode": [instructions: call_delegate(decode_input)] "delegate_payload:: Dict[bytes, index]) ] ``` ## Runtime After we expose the method name via this change, the backend can access the method name and load the method that matches the top-level method ``` Result<DelegateHandle*> QNNBackend::init( BackendInitContext& context, FreeableBuffer* processed, ArrayRef<CompileSpec> compile_specs) { const char* method_name = context.get_method_name() // for example, "prefill" handle = qnn_backend.load(method_name) return handle } ``` This is to unblock sharing weights between prefill and decode when using the HTP backend. Reviewed By: kimishpatel, iseeyuan Differential Revision: D65386597
1 parent 545535b commit 9c0d9ab

File tree

5 files changed

+67
-9
lines changed

5 files changed

+67
-9
lines changed

runtime/backend/backend_execution_context.h

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,9 @@ class BackendExecutionContext final {
2121
public:
2222
BackendExecutionContext(
2323
EventTracer* event_tracer = nullptr,
24-
MemoryAllocator* temp_allocator = nullptr)
25-
: event_tracer_(event_tracer), temp_allocator_(temp_allocator) {}
24+
MemoryAllocator* temp_allocator = nullptr,
25+
const char* method_name = nullptr)
26+
: event_tracer_(event_tracer), temp_allocator_(temp_allocator), method_name_(method_name) {}
2627

2728
/**
2829
* Returns a pointer to an instance of EventTracer to do profiling/debugging
@@ -52,9 +53,17 @@ class BackendExecutionContext final {
5253
return temp_allocator_;
5354
}
5455

56+
/**
57+
* Returns the method name pointer supplied to the constructor, or nullptr when no name was given.
58+
*/
59+
const char* get_method_name() {
60+
return method_name_;
61+
}
62+
5563
private:
5664
EventTracer* event_tracer_ = nullptr;
5765
MemoryAllocator* temp_allocator_ = nullptr;
66+
const char* method_name_ = nullptr;
5867
};
5968

6069
} // namespace runtime

runtime/backend/backend_init_context.h

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ namespace runtime {
1818
*/
1919
class BackendInitContext final {
2020
public:
21-
explicit BackendInitContext(MemoryAllocator* runtime_allocator)
22-
: runtime_allocator_(runtime_allocator) {}
21+
explicit BackendInitContext(MemoryAllocator* runtime_allocator, const char* method_name = nullptr)
22+
: runtime_allocator_(runtime_allocator), method_name_(method_name) {}
2323

2424
/** Get the runtime allocator passed from Method. It's the same runtime
2525
* executor used by the standard executor runtime and the life span is the
@@ -28,9 +28,21 @@ class BackendInitContext final {
2828
MemoryAllocator* get_runtime_allocator() {
2929
return runtime_allocator_;
3030
}
31+
32+
/** Returns the name of the method being loaded by the ExecuTorch runtime.
33+
* This is usually "forward", but a .pte file may contain several methods, for
34+
* example separate "prefill" and "decode" methods. In that case the value is
35+
* "prefill" while the client loads the prefill method and "decode" while it
36+
* loads the decode method. The pointer is the one passed to the constructor
37+
* and is nullptr when no name was supplied.
38+
*/
39+
const char* get_method_name() {
40+
return method_name_;
41+
}
3142

3243
private:
3344
MemoryAllocator* runtime_allocator_ = nullptr;
45+
const char* method_name_ = nullptr;
3446
};
3547

3648
} // namespace runtime

runtime/executor/method.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -598,6 +598,7 @@ Error Method::init(executorch_flatbuffer::ExecutionPlan* s_plan) {
598598
init_state_ =
599599
InitializationState::InitializationFailed; // Until proven otherwise
600600
serialization_plan_ = s_plan;
601+
// Point directly at the NUL-terminated name stored in the serialized plan.
// Going through str() would create a temporary std::string that is destroyed
// at the end of this statement, leaving method_name_ dangling. The pointer
// stays valid only while the program data backing s_plan is alive.
method_name_ = s_plan->name()->c_str();
601602
auto method_allocator = memory_manager_->method_allocator();
602603

603604
{
@@ -626,7 +627,7 @@ Error Method::init(executorch_flatbuffer::ExecutionPlan* s_plan) {
626627

627628
for (size_t i = 0; i < n_delegate; ++i) {
628629
const auto& delegate = *delegates->Get(i);
629-
BackendInitContext backend_init_context(method_allocator);
630+
BackendInitContext backend_init_context(method_allocator, method_name_);
630631
Error err = BackendDelegate::Init(
631632
delegate, program_, backend_init_context, &delegates_[i]);
632633
if (err != Error::Ok) {
@@ -1098,7 +1099,8 @@ Error Method::execute_instruction() {
10981099
step_state_.instr_idx);
10991100
BackendExecutionContext backend_execution_context(
11001101
/*event_tracer*/ event_tracer_,
1101-
/*temp_allocator*/ temp_allocator_);
1102+
/*temp_allocator*/ temp_allocator_,
1103+
/*method_name_*/ method_name_);
11021104
err = delegates_[delegate_idx].Execute(
11031105
backend_execution_context,
11041106
chain.argument_lists_[step_state_.instr_idx].data());

runtime/executor/method.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,7 @@ class Method final {
328328

329329
size_t n_chains_;
330330
Chain* chains_;
331+
const char* method_name_;
331332

332333
InitializationState init_state_;
333334

runtime/executor/test/backend_integration_test.cpp

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ class StubBackend final : public BackendInterface {
5656
FreeableBuffer*,
5757
ArrayRef<CompileSpec>,
5858
BackendInitContext&)>;
59-
using ExecuteFn = std::function<Error(DelegateHandle*, EValue**)>;
59+
using ExecuteFn = std::function<Error(DelegateHandle*, EValue**, BackendExecutionContext&)>;
6060
using DestroyFn = std::function<void(DelegateHandle*)>;
6161

6262
// Default name that this backend is registered as.
@@ -98,7 +98,7 @@ class StubBackend final : public BackendInterface {
9898
DelegateHandle* handle,
9999
EValue** args) const override {
100100
if (execute_fn_) {
101-
return execute_fn_.value()(handle, args);
101+
return execute_fn_.value()(handle, args, context);
102102
}
103103
// Return a benign value otherwise.
104104
return Error::Ok;
@@ -404,7 +404,7 @@ TEST_P(BackendIntegrationTest, EndToEndTestWithProcessedAsHandle) {
404404
// FreeableBuffer.
405405
DelegateHandle* execute_handle = nullptr;
406406
StubBackend::singleton().install_execute(
407-
[&](DelegateHandle* handle, ET_UNUSED EValue** args) -> Error {
407+
[&](DelegateHandle* handle, ET_UNUSED EValue** args, ET_UNUSED BackendExecutionContext& context) -> Error {
408408
execute_handle = handle;
409409
auto* processed = reinterpret_cast<FreeableBuffer*>(handle);
410410

@@ -527,6 +527,40 @@ TEST_P(BackendIntegrationTest, SegmentInfoIsPassedIntoDataLoader) {
527527
EXPECT_EQ(backend_load_was_called, using_segments());
528528
}
529529

530+
// Checks that the method name is visible to a backend's init() through
// BackendInitContext::get_method_name().
TEST_P(BackendIntegrationTest, GetMethodNameDuringInitSuccess) {
  Result<FileDataLoader> loader = FileDataLoader::from(program_path());
  ASSERT_EQ(loader.error(), Error::Ok);
  StubBackend::singleton().install_init(
      [&](ET_UNUSED FreeableBuffer* processed,
          ET_UNUSED ArrayRef<CompileSpec> compile_specs,
          BackendInitContext& backend_init_context)
          -> Result<DelegateHandle*> {
        // Compare the string contents: EXPECT_EQ on two `const char*` values
        // only compares the pointer addresses.
        EXPECT_STREQ(backend_init_context.get_method_name(), "forward");
        return nullptr;
      });
  Result<Program> program = Program::load(&loader.get());
  ASSERT_EQ(program.error(), Error::Ok);
  // NOTE(review): this test only loads the Program; no method is loaded, so
  // the init hook installed above is presumably never invoked. Confirm, and
  // load the "forward" method here so that the expectation actually runs.
}
548+
549+
// Checks that the method name is visible to a backend's execute() through
// BackendExecutionContext::get_method_name().
TEST_P(BackendIntegrationTest, GetMethodNameDuringExecuteSuccess) {
  Result<FileDataLoader> loader = FileDataLoader::from(program_path());
  ASSERT_EQ(loader.error(), Error::Ok);
  StubBackend::singleton().install_execute(
      [&](ET_UNUSED DelegateHandle* handle,
          ET_UNUSED EValue** args,
          BackendExecutionContext& backend_execution_context) -> Error {
        // Compare the string contents: EXPECT_EQ on two `const char*` values
        // only compares the pointer addresses.
        EXPECT_STREQ(backend_execution_context.get_method_name(), "forward");
        return Error::Ok;
      });
  Result<Program> program = Program::load(&loader.get());
  ASSERT_EQ(program.error(), Error::Ok);
  // NOTE(review): this test only loads the Program; no method is loaded or
  // executed, so the execute hook installed above is presumably never
  // invoked. Confirm, and load and execute the "forward" method here so that
  // the expectation actually runs.
}
563+
530564
// TODO: Add more tests for the runtime-to-backend interface. E.g.:
531565
// - Errors during init() or execute() result in runtime init/execution failures
532566
// - Correct values are passed to init()/execute()

0 commit comments

Comments
 (0)