Skip to content

Commit 22864ef

Browse files
committed
[6/N] Add update in module
- Expose the update API in module. Inside the module. invoke the method. Update API - Did a bit refactor for the `StubBackend` test code, such that it can shared between `backend_interface_update_test.cpp` and `module_test.cpp` Differential Revision: [D76172680](https://our.internmc.facebook.com/intern/diff/D76172680/) [ghstack-poisoned]
1 parent f1f105c commit 22864ef

File tree

11 files changed

+259
-126
lines changed

11 files changed

+259
-126
lines changed

extension/module/module.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,5 +308,13 @@ runtime::Error Module::set_output(
308308
output_tensor.mutable_data_ptr(), output_tensor.nbytes(), output_index);
309309
}
310310

311+
runtime::Error Module::update(
312+
const std::string& method_name,
313+
runtime::ArrayRef<runtime::Entry> backend_options) {
314+
ET_CHECK_OK_OR_RETURN_ERROR(load_method(method_name));
315+
auto& method = methods_.at(method_name).method;
316+
return method->update(backend_options);
317+
}
318+
311319
} // namespace extension
312320
} // namespace executorch

extension/module/module.h

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,32 @@ class Module {
457457
return set_output("forward", std::move(output_value), output_index);
458458
}
459459

460+
/**
461+
* EXPERIMENTAL: Updates backend options for a specific method.
462+
* Loads the program and method before updating if needed.
463+
*
464+
* @param[in] method_name The name of the method to update.
465+
* @param[in] backend_options The backend options to update the method with.
466+
*
467+
* @returns An Error to indicate success or failure.
468+
*/
469+
ET_EXPERIMENTAL ET_NODISCARD runtime::Error update(
470+
const std::string& method_name,
471+
runtime::ArrayRef<runtime::Entry> backend_options);
472+
473+
/**
474+
* EXPERIMENTAL: Updates backend options for the 'forward' method.
475+
* Loads the program and method before updating if needed.
476+
*
477+
* @param[in] backend_options The backend options to update the method with.
478+
*
479+
* @returns An Error to indicate success or failure.
480+
*/
481+
ET_EXPERIMENTAL ET_NODISCARD inline runtime::Error update(
482+
runtime::ArrayRef<runtime::Entry> backend_options) {
483+
return update("forward", backend_options);
484+
}
485+
460486
/**
461487
* Retrieves the EventTracer instance being used by the Module.
462488
* EventTracer is used for tracking and logging events during the execution

extension/module/test/module_test.cpp

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,21 +16,32 @@
1616
#include <executorch/extension/data_loader/file_data_loader.h>
1717
#include <executorch/extension/tensor/tensor.h>
1818
#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
19+
#include <executorch/runtime/backend/backend_options.h>
20+
#include <executorch/runtime/backend/backend_options_map.h>
21+
#include <executorch/runtime/executor/test/stub_backend.h>
1922

2023
using namespace ::executorch::extension;
2124
using namespace ::executorch::runtime;
25+
using executorch::runtime::BackendOptions;
26+
using executorch::runtime::Entry;
27+
using executorch::runtime::IntKey;
2228

2329
class ModuleTest : public ::testing::Test {
2430
protected:
2531
static void SetUpTestSuite() {
2632
model_path_ = std::getenv("ET_MODULE_ADD_PATH");
2733
add_mul_path_ = std::getenv("ET_MODULE_ADD_MUL_PROGRAM_PATH");
2834
add_mul_data_path_ = std::getenv("ET_MODULE_ADD_MUL_DATA_PATH");
35+
stub_model_path_ = std::getenv("ET_MODULE_ADD_MUL_DELEGATED_PATH");
36+
37+
// Register the StubBackend for testing
38+
StubBackend::register_singleton();
2939
}
3040

3141
static inline std::string model_path_;
3242
static inline std::string add_mul_path_;
3343
static inline std::string add_mul_data_path_;
44+
static inline std::string stub_model_path_;
3445
};
3546

3647
TEST_F(ModuleTest, TestLoad) {
@@ -466,3 +477,34 @@ TEST_F(ModuleTest, TestPTD) {
466477
auto tensor = make_tensor_ptr({2, 2}, {2.f, 3.f, 4.f, 2.f});
467478
ASSERT_EQ(module.forward(tensor).error(), Error::Ok);
468479
}
480+
481+
TEST_F(ModuleTest, TestUpdate) {
482+
Module module(stub_model_path_);
483+
484+
BackendOptionsMap<3> map;
485+
BackendOptions<1> backend_options;
486+
int new_num_threads = 4;
487+
backend_options.set_option(IntKey("NumberOfThreads"), new_num_threads);
488+
map.add("StubBackend", backend_options.view());
489+
490+
// Test update method with specific method name
491+
const auto update_result = module.update("forward", map.entries());
492+
EXPECT_EQ(update_result, Error::Ok);
493+
494+
ASSERT_EQ(StubBackend::singleton().num_threads(), new_num_threads);
495+
496+
}
497+
498+
TEST_F(ModuleTest, TestUpdateNonExistentMethod) {
499+
Module module(stub_model_path_);
500+
501+
BackendOptionsMap<3> map;
502+
BackendOptions<1> backend_options;
503+
int new_num_threads = 4;
504+
backend_options.set_option(IntKey("NumberOfThreads"), new_num_threads);
505+
map.add("StubBackend", backend_options.view());
506+
507+
// Test update method with non-existent method name
508+
const auto update_result = module.update("nonexistent", map.entries());
509+
EXPECT_NE(update_result, Error::Ok);
510+
}

extension/module/test/targets.bzl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ def define_common_targets(is_fbcode=False):
1919
"ET_MODULE_ADD_PATH": "$(location fbcode//executorch/test/models:exported_programs[ModuleAdd.pte])",
2020
"ET_MODULE_ADD_MUL_PROGRAM_PATH": "$(location fbcode//executorch/test/models:exported_program_and_data[ModuleAddMul.pte])",
2121
"ET_MODULE_ADD_MUL_DATA_PATH": "$(location fbcode//executorch/test/models:exported_program_and_data[ModuleAddMul.ptd])",
22+
"ET_MODULE_ADD_MUL_DELEGATED_PATH": "$(location fbcode//executorch/test/models:exported_delegated_add_mul[ModuleAddMul.pte])",
2223
}
2324

2425
for aten_mode in get_aten_mode_options():
@@ -35,6 +36,7 @@ def define_common_targets(is_fbcode=False):
3536
"//executorch/extension/module:module" + aten_suffix,
3637
"//executorch/extension/tensor:tensor" + aten_suffix,
3738
"//executorch/runtime/core/exec_aten/testing_util:tensor_util" + aten_suffix,
39+
"//executorch/runtime/executor/test:stub_backend",
3840
],
3941
env = modules_env,
4042
platforms = [CXX, ANDROID], # Cannot bundle resources on Apple platform.

runtime/backend/backend_options.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,10 @@
77
*/
88

99
#pragma once
10+
#include <executorch/runtime/core/array_ref.h>
1011
#include <executorch/runtime/core/error.h>
1112
#include <cstddef>
1213
#include <cstring>
13-
#include <executorch/runtime/core/error.h>
14-
#include <executorch/runtime/core/array_ref.h>
1514

1615
namespace executorch {
1716
namespace runtime {
@@ -125,7 +124,7 @@ class BackendOptions {
125124
executorch::runtime::ArrayRef<BackendOption> view() const {
126125
return executorch::runtime::ArrayRef<BackendOption>(options_, size_);
127126
}
128-
127+
129128
private:
130129
BackendOption options_[MaxCapacity]{};
131130
size_t size_;

runtime/executor/method.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1513,7 +1513,8 @@ Error Method::experimental_step() {
15131513
return step();
15141514
}
15151515

1516-
Error Method::update(executorch::runtime::ArrayRef<executorch::runtime::Entry> backend_option) {
1516+
Error Method::update(
1517+
executorch::runtime::ArrayRef<executorch::runtime::Entry> backend_option) {
15171518
for (const auto& entry : backend_option) {
15181519
const char* backend_name = entry.backend_name;
15191520
auto backend_options = entry.options;

runtime/executor/method.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -241,13 +241,14 @@ class Method final {
241241
/// DEPRECATED: Use `reset_execution()` instead.
242242
ET_DEPRECATED ET_NODISCARD Error experimental_reset_execution();
243243

244-
/**
244+
/**
245245
* EXPERIMENTAL: Update backend options, which will be dispatched to different backends.
246246
*
247247
* @retval Error::Ok step succeeded
248248
* @retval non-Ok Method update fails
249249
*/
250-
ET_EXPERIMENTAL ET_NODISCARD Error update(executorch::runtime::ArrayRef<executorch::runtime::Entry> backend_option);
250+
ET_EXPERIMENTAL ET_NODISCARD Error update(
251+
executorch::runtime::ArrayRef<executorch::runtime::Entry> backend_option);
251252

252253
/**
253254
* Returns the MethodMeta that corresponds to the calling Method.

runtime/executor/test/method_update_test.cpp

Lines changed: 41 additions & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -10,140 +10,61 @@
1010
#include <filesystem>
1111

1212
#include <executorch/extension/data_loader/file_data_loader.h>
13+
#include <executorch/runtime/backend/backend_options.h>
14+
#include <executorch/runtime/backend/backend_options_map.h>
15+
#include <executorch/runtime/backend/backend_update_context.h>
16+
#include <executorch/runtime/backend/interface.h>
17+
#include <executorch/runtime/core/error.h>
1318
#include <executorch/runtime/core/exec_aten/exec_aten.h>
19+
#include <executorch/runtime/core/result.h>
1420
#include <executorch/runtime/executor/method.h>
1521
#include <executorch/runtime/executor/program.h>
1622
#include <executorch/runtime/executor/test/managed_memory_manager.h>
23+
#include <executorch/runtime/executor/test/stub_backend.h>
1724
#include <executorch/runtime/platform/runtime.h>
1825
#include <executorch/test/utils/DeathTest.h>
1926
#include <gtest/gtest.h>
20-
#include <executorch/runtime/backend/interface.h>
21-
#include <executorch/runtime/backend/backend_update_context.h>
22-
#include <executorch/runtime/backend/backend_options.h>
23-
#include <executorch/runtime/backend/backend_options_map.h>
24-
#include <executorch/runtime/core/error.h>
25-
#include <executorch/runtime/core/result.h>
26-
2727

2828
using namespace ::testing;
2929
using executorch::aten::ArrayRef;
30+
using executorch::runtime::BackendExecutionContext;
31+
using executorch::runtime::BackendInitContext;
32+
using executorch::runtime::BackendInterface;
33+
using executorch::runtime::BackendOption;
34+
using executorch::runtime::BackendOptions;
35+
using executorch::runtime::BackendOptionsMap;
36+
using executorch::runtime::BackendUpdateContext;
37+
using executorch::runtime::BoolKey;
38+
using executorch::runtime::CompileSpec;
39+
using executorch::runtime::DataLoader;
40+
using executorch::runtime::DelegateHandle;
41+
using executorch::runtime::Entry;
3042
using executorch::runtime::Error;
3143
using executorch::runtime::EValue;
44+
using executorch::runtime::FreeableBuffer;
45+
using executorch::runtime::IntKey;
3246
using executorch::runtime::Method;
47+
using executorch::runtime::OptionType;
3348
using executorch::runtime::Program;
3449
using executorch::runtime::Result;
3550
using executorch::runtime::testing::ManagedMemoryManager;
3651
using torch::executor::util::FileDataLoader;
37-
using executorch::runtime::BackendExecutionContext;
38-
using executorch::runtime::BackendInitContext;
39-
using executorch::runtime::BackendInterface;
40-
using executorch::runtime::BackendUpdateContext;
41-
using executorch::runtime::BackendOption;
42-
using executorch::runtime::BackendOptions;
43-
using executorch::runtime::BackendOptionsMap;
44-
using executorch::runtime::BoolKey;
45-
using executorch::runtime::IntKey;
46-
using executorch::runtime::Entry;
47-
using executorch::runtime::OptionType;
48-
using executorch::runtime::CompileSpec;
49-
using executorch::runtime::DataLoader;
50-
using executorch::runtime::DelegateHandle;
51-
using executorch::runtime::FreeableBuffer;
52-
52+
5353
constexpr size_t kDefaultNonConstMemBytes = 32 * 1024U;
5454
constexpr size_t kDefaultRuntimeMemBytes = 32 * 1024U;
5555

56-
/**
57-
* A backend class whose methods can be overridden individually.
58-
*/
59-
class StubBackend final : public BackendInterface {
60-
public:
61-
62-
// Default name that this backend is registered as.
63-
static constexpr char kName[] = "StubBackend";
64-
65-
bool is_available() const override {
66-
return true;
67-
}
68-
69-
Result<DelegateHandle*> init(
70-
BackendInitContext& context,
71-
FreeableBuffer* processed,
72-
ArrayRef<CompileSpec> compile_specs) const override {
73-
return nullptr;
74-
}
75-
76-
Error execute(
77-
BackendExecutionContext& context,
78-
DelegateHandle* handle,
79-
EValue** args) const override {
80-
return Error::Ok;
81-
}
82-
83-
int num_threads() const {
84-
return num_threads_;
85-
}
86-
87-
Error update(
88-
BackendUpdateContext& context,
89-
const executorch::runtime::ArrayRef<BackendOption>& backend_options) const override {
90-
int success_update = 0;
91-
for (const auto& backend_option : backend_options) {
92-
if (strcmp(backend_option.key, "NumberOfThreads") == 0) {
93-
if (backend_option.type == OptionType::INT) {
94-
num_threads_ = backend_option.value.int_value;
95-
success_update++;
96-
}
97-
}
98-
}
99-
if (success_update == backend_options.size()) {
100-
return Error::Ok;
101-
}
102-
return Error::InvalidArgument;
103-
}
104-
105-
/**
106-
* Registers the singleton instance if not already registered.
107-
*
108-
* Note that this can be used to install the stub as the implementation for
109-
* any export-time backend by passing in the right name, as long as no other
110-
* backend with that name has been registered yet.
111-
*/
112-
static Error register_singleton(const char* name = kName) {
113-
if (!registered_) {
114-
registered_ = true;
115-
return executorch::runtime::register_backend({name, &singleton_});
116-
}
117-
return Error::Ok;
118-
}
119-
120-
/**
121-
* Returns the instance that was added to the backend registry.
122-
*/
123-
static StubBackend& singleton() {
124-
return singleton_;
125-
}
126-
127-
private:
128-
static bool registered_;
129-
static StubBackend singleton_;
130-
mutable int num_threads_ = 1;
131-
};
132-
133-
bool StubBackend::registered_ = false;
134-
StubBackend StubBackend::singleton_;
135-
13656
class MethodUpdateTest : public ::testing::Test {
13757
protected:
13858
void load_program() {
139-
// Since these tests cause ET_LOG to be called, the PAL must be initialized
140-
// first.
141-
executorch::runtime::runtime_init();
142-
143-
// Create a loader for the serialized program.
59+
// Since these tests cause ET_LOG to be called, the PAL must be initialized
60+
// first.
61+
executorch::runtime::runtime_init();
62+
63+
// Create a loader for the serialized program.
14464
ASSERT_EQ(StubBackend::register_singleton(), Error::Ok);
145-
146-
auto loader_res = FileDataLoader::from(std::getenv("ET_MODULE_ADD_MUL_DELEGATED_PATH"));
65+
66+
auto loader_res =
67+
FileDataLoader::from(std::getenv("ET_MODULE_ADD_MUL_DELEGATED_PATH"));
14768
ASSERT_EQ(loader_res.error(), Error::Ok);
14869
loader_ = std::make_unique<FileDataLoader>(std::move(loader_res.get()));
14970

@@ -152,7 +73,7 @@ class StubBackend final : public BackendInterface {
15273
ASSERT_EQ(program_res.error(), Error::Ok);
15374
program_ = std::make_unique<Program>(std::move(program_res.get()));
15475
}
155-
76+
15677
void SetUp() override {
15778
executorch::runtime::runtime_init();
15879

@@ -161,22 +82,22 @@ class StubBackend final : public BackendInterface {
16182

16283
private:
16384
std::unique_ptr<FileDataLoader> loader_;
164-
85+
16586
protected:
16687
std::unique_ptr<Program> program_;
16788
};
16889

16990
TEST_F(MethodUpdateTest, MoveTest) {
170-
BackendInterface* backend =
171-
executorch::runtime::get_backend_class(StubBackend::kName);
172-
ASSERT_EQ(backend, &StubBackend::singleton());
173-
91+
BackendInterface* backend =
92+
executorch::runtime::get_backend_class(StubBackend::kName);
93+
ASSERT_EQ(backend, &StubBackend::singleton());
94+
17495
ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes);
17596
Result<Method> method = program_->load_method("forward", &mmm.get());
176-
// Check that the default number of threads is 1.
97+
// Check that the default number of threads is 1.
17798
ASSERT_EQ(StubBackend::singleton().num_threads(), 1);
17899
ASSERT_EQ(method.error(), Error::Ok);
179-
100+
180101
BackendOptionsMap<3> map;
181102
BackendOptions<1> backend_options;
182103
int new_num_threads = 4;
@@ -185,4 +106,5 @@ class StubBackend final : public BackendInterface {
185106
Error update_result = method->update(map.entries());
186107
ASSERT_EQ(update_result, Error::Ok);
187108
ASSERT_EQ(StubBackend::singleton().num_threads(), new_num_threads);
188-
}
109+
}
110+

0 commit comments

Comments
 (0)