Skip to content

Commit fafff0c

Browse files
pytorchbotcccclai
andauthored
[5.1/ N] set_option/get_option API with {backend_name, backend options} only (#11877)
This PR was created by the merge bot to help merge the original PR into the main branch. ghstack PR number: #11865 by @cccclai ^ Please use this as the source of truth for the PR details, comments, and reviews ghstack PR base: https://github.com/pytorch/executorch/tree/gh/cccclai/31/base ghstack PR head: https://github.com/pytorch/executorch/tree/gh/cccclai/31/head Merge bot PR base: https://github.com/pytorch/executorch/tree/gh/cccclai/23/orig Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/cccclai/31/orig @diff-train-skip-merge --------- Co-authored-by: Chen Lai <[email protected]> Co-authored-by: cccclai <[email protected]>
1 parent 44ab83a commit fafff0c

File tree

3 files changed

+180
-1
lines changed

3 files changed

+180
-1
lines changed

runtime/backend/interface.cpp

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,5 +66,42 @@ Result<const char*> get_backend_name(size_t index) {
6666
return registered_backends[index].name;
6767
}
6868

69+
Error set_option(
70+
const char* backend_name,
71+
const executorch::runtime::Span<executorch::runtime::BackendOption>
72+
backend_options) {
73+
auto backend_class = get_backend_class(backend_name);
74+
if (!backend_class) {
75+
return Error::NotFound;
76+
}
77+
78+
BackendOptionContext backend_option_context;
79+
Error result =
80+
backend_class->set_option(backend_option_context, backend_options);
81+
if (result != Error::Ok) {
82+
return result;
83+
}
84+
return Error::Ok;
85+
}
86+
87+
Error get_option(
88+
const char* backend_name,
89+
executorch::runtime::Span<executorch::runtime::BackendOption>
90+
backend_options) {
91+
auto backend_class = get_backend_class(backend_name);
92+
if (!backend_class) {
93+
return Error::NotFound;
94+
}
95+
BackendOptionContext backend_option_context;
96+
executorch::runtime::Span<BackendOption> backend_options_ref(
97+
backend_options.data(), backend_options.size());
98+
auto result =
99+
backend_class->get_option(backend_option_context, backend_options_ref);
100+
if (result != Error::Ok) {
101+
return result;
102+
}
103+
return Error::Ok;
104+
}
105+
69106
} // namespace ET_RUNTIME_NAMESPACE
70107
} // namespace executorch

runtime/backend/interface.h

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,34 @@ size_t get_num_registered_backends();
183183
*/
184184
Result<const char*> get_backend_name(size_t index);
185185

186+
/**
187+
* Sets backend options for a specific backend.
188+
*
189+
* @param backend_name The name of the backend to set options for
190+
* @param backend_options The backend option list containing the options
191+
* to set
192+
* @return Error::Ok on success, Error::NotFound if backend is not found, or
193+
* other error codes on failure
194+
*/
195+
Error set_option(
196+
const char* backend_name,
197+
const executorch::runtime::Span<executorch::runtime::BackendOption>
198+
backend_options);
199+
200+
/**
201+
* Retrieves backend options for a specific backend.
202+
*
203+
* @param backend_name The name of the backend to get options from
204+
* @param backend_options The backend option objects that will be filled with
205+
* the populated values from the backend
206+
* @return Error::Ok on success, Error::NotFound if backend is not found, or
207+
* other error codes on failure
208+
*/
209+
Error get_option(
210+
const char* backend_name,
211+
executorch::runtime::Span<executorch::runtime::BackendOption>
212+
backend_options);
213+
186214
} // namespace ET_RUNTIME_NAMESPACE
187215
} // namespace executorch
188216

runtime/backend/test/backend_interface_update_test.cpp

Lines changed: 115 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,11 @@
77
*/
88

99
#include <executorch/runtime/backend/interface.h>
10+
#include <executorch/runtime/backend/options.h>
1011
#include <executorch/runtime/platform/runtime.h>
1112

1213
#include <gtest/gtest.h>
14+
#include <memory>
1315

1416
using namespace ::testing;
1517
using executorch::runtime::ArrayRef;
@@ -61,7 +63,8 @@ class MockBackend : public BackendInterface {
6163
int success_update = 0;
6264
for (const auto& backend_option : backend_options) {
6365
if (strcmp(backend_option.key, "Backend") == 0) {
64-
if (std::holds_alternative<std::array<char, 256>>(
66+
if (std::holds_alternative<
67+
std::array<char, executorch::runtime::kMaxOptionValueLength>>(
6568
backend_option.value)) {
6669
// Store the value in our member variable
6770
const auto& arr =
@@ -285,3 +288,114 @@ TEST_F(BackendInterfaceUpdateTest, UpdateBetweenExecutes) {
285288
ASSERT_TRUE(mock_backend->target_backend.has_value());
286289
EXPECT_STREQ(mock_backend->target_backend.value().c_str(), "NPU");
287290
}
291+
292+
// Mock backend for testing
293+
class StubBackend : public BackendInterface {
294+
public:
295+
~StubBackend() override = default;
296+
297+
bool is_available() const override {
298+
return true;
299+
}
300+
301+
Result<DelegateHandle*> init(
302+
BackendInitContext& context,
303+
FreeableBuffer* processed,
304+
ArrayRef<CompileSpec> compile_specs) const override {
305+
return nullptr;
306+
}
307+
308+
Error execute(
309+
BackendExecutionContext& context,
310+
DelegateHandle* handle,
311+
EValue** args) const override {
312+
return Error::Ok;
313+
}
314+
315+
Error get_option(
316+
BackendOptionContext& context,
317+
executorch::runtime::Span<executorch::runtime::BackendOption>&
318+
backend_options) override {
319+
// For testing purposes, just record that get_option was called
320+
// and verify the input parameters
321+
get_option_called = true;
322+
get_option_call_count++;
323+
last_get_option_size = backend_options.size();
324+
325+
// Verify that the expected option key is present and modify the value
326+
for (size_t i = 0; i < backend_options.size(); ++i) {
327+
if (strcmp(backend_options[i].key, "NumberOfThreads") == 0) {
328+
// Set the value to what was stored by set_option
329+
backend_options[i].value = last_num_threads;
330+
found_expected_key = true;
331+
break;
332+
}
333+
}
334+
335+
return Error::Ok;
336+
}
337+
338+
Error set_option(
339+
BackendOptionContext& context,
340+
const executorch::runtime::Span<executorch::runtime::BackendOption>&
341+
backend_options) override {
342+
// Store the options for verification
343+
last_options_size = backend_options.size();
344+
if (backend_options.size() > 0) {
345+
for (const auto& option : backend_options) {
346+
if (strcmp(option.key, "NumberOfThreads") == 0) {
347+
if (auto* val = std::get_if<int>(&option.value)) {
348+
last_num_threads = *val;
349+
}
350+
}
351+
}
352+
}
353+
return Error::Ok;
354+
}
355+
356+
// Mutable for testing verification
357+
size_t last_options_size = 0;
358+
int last_num_threads = 0;
359+
bool get_option_called = false;
360+
int get_option_call_count = 0;
361+
size_t last_get_option_size = 0;
362+
bool found_expected_key = false;
363+
};
364+
365+
class BackendUpdateTest : public ::testing::Test {
366+
protected:
367+
void SetUp() override {
368+
// Since these tests cause ET_LOG to be called, the PAL must be initialized
369+
// first.
370+
executorch::runtime::runtime_init();
371+
372+
// Register the stub backend
373+
stub_backend = std::make_unique<StubBackend>();
374+
Backend backend_config{"StubBackend", stub_backend.get()};
375+
auto register_result = register_backend(backend_config);
376+
ASSERT_EQ(register_result, Error::Ok);
377+
}
378+
379+
std::unique_ptr<StubBackend> stub_backend;
380+
};
381+
382+
// Test basic string functionality
383+
TEST_F(BackendUpdateTest, TestSetGetOption) {
384+
BackendOptions<1> backend_options;
385+
int new_num_threads = 4;
386+
backend_options.set_option("NumberOfThreads", new_num_threads);
387+
388+
auto status = set_option("StubBackend", backend_options.view());
389+
ASSERT_EQ(status, Error::Ok);
390+
391+
// Set up the default option, which will be populuated by the get_option call
392+
BackendOption ref_backend_option{"NumberOfThreads", 0};
393+
status = get_option("StubBackend", ref_backend_option);
394+
395+
// Verify that the backend actually received the options
396+
ASSERT_TRUE(std::get<int>(ref_backend_option.value) == new_num_threads);
397+
398+
// Verify that the backend actually update the options
399+
ASSERT_EQ(stub_backend->last_options_size, 1);
400+
ASSERT_EQ(stub_backend->last_num_threads, new_num_threads);
401+
}

0 commit comments

Comments
 (0)