Skip to content

Commit 1162d32

Browse files
committed
Add get_option/set_option APIs
Differential Revision: [D76825663](https://our.internmc.facebook.com/intern/diff/D76825663/) ghstack-source-id: 290994800 Pull Request resolved: #11758
1 parent 921cfdc commit 1162d32

File tree

4 files changed

+283
-0
lines changed

4 files changed

+283
-0
lines changed

runtime/backend/backend_update.h

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
#include <executorch/runtime/backend/backend_options_map.h>
11+
#include <executorch/runtime/backend/backend_update_context.h>
12+
#include <executorch/runtime/backend/interface.h>
13+
#include <executorch/runtime/core/error.h>
14+
#include <cstddef>
15+
#include <cstring>
16+
17+
using executorch::runtime::BackendOptionsMap;
18+
19+
namespace executorch {
20+
namespace runtime {
21+
22+
Error get_option(
23+
executorch::runtime::Span<executorch::runtime::Entry> backend_options_map) {
24+
for (auto& entry : backend_options_map) {
25+
const char* backend_name = entry.backend_name;
26+
auto backend_options = entry.options;
27+
28+
auto backend_class = get_backend_class(backend_name);
29+
if (!backend_class) {
30+
return Error::NotFound;
31+
}
32+
33+
executorch::runtime::BackendUpdateContext backend_update_context;
34+
executorch::runtime::Span<BackendOption> backend_options_ref(
35+
backend_options.data(), backend_options.size());
36+
auto result =
37+
backend_class->get_option(backend_update_context, backend_options_ref);
38+
if (result != Error::Ok) {
39+
return result;
40+
}
41+
}
42+
return Error::Ok;
43+
}
44+
45+
Error set_option(
46+
const executorch::runtime::Span<executorch::runtime::Entry> backend_options_map) {
47+
for (const auto& entry : backend_options_map) {
48+
const char* backend_name = entry.backend_name;
49+
auto backend_options = entry.options;
50+
51+
auto backend_class = get_backend_class(backend_name);
52+
if (!backend_class) {
53+
return Error::NotFound;
54+
}
55+
56+
executorch::runtime::BackendUpdateContext backend_update_context;
57+
auto update_result =
58+
backend_class->set_option(backend_update_context, backend_options);
59+
if (update_result != Error::Ok) {
60+
return update_result;
61+
}
62+
}
63+
return Error::Ok;
64+
}
65+
66+
} // namespace runtime
67+
} // namespace executorch

runtime/backend/targets.bzl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,25 @@ def define_common_targets():
2828
],
2929
)
3030

31+
runtime.cxx_library(
32+
name = "backend_update" + aten_suffix,
33+
exported_headers = [
34+
"backend_update.h",
35+
],
36+
preprocessor_flags = ["-DUSE_ATEN_LIB"] if aten_mode else [],
37+
visibility = [
38+
"//executorch/...",
39+
"@EXECUTORCH_CLIENTS",
40+
],
41+
exported_deps = [
42+
"//executorch/runtime/core:core",
43+
"//executorch/runtime/core:evalue" + aten_suffix,
44+
"//executorch/runtime/core:event_tracer" + aten_suffix,
45+
":backend_options_map" + aten_suffix,
46+
":interface" + aten_suffix,
47+
],
48+
)
49+
3150
runtime.cxx_library(
3251
name = "interface" + aten_suffix,
3352
srcs = [
Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/runtime/backend/backend_options.h>
10+
#include <executorch/runtime/backend/backend_options_map.h>
11+
#include <executorch/runtime/backend/backend_update.h>
12+
#include <executorch/runtime/backend/interface.h>
13+
#include <executorch/runtime/core/array_ref.h>
14+
#include <executorch/runtime/platform/runtime.h>
15+
#include <gtest/gtest.h>
16+
17+
using namespace ::testing;
18+
using executorch::runtime::ArrayRef;
19+
using executorch::runtime::Backend;
20+
using executorch::runtime::BackendExecutionContext;
21+
using executorch::runtime::BackendInitContext;
22+
using executorch::runtime::BackendInterface;
23+
using executorch::runtime::BackendOptions;
24+
using executorch::runtime::BackendOptionsMap;
25+
using executorch::runtime::BackendUpdateContext;
26+
using executorch::runtime::BoolKey;
27+
using executorch::runtime::CompileSpec;
28+
using executorch::runtime::DelegateHandle;
29+
using executorch::runtime::Error;
30+
using executorch::runtime::EValue;
31+
using executorch::runtime::FreeableBuffer;
32+
using executorch::runtime::IntKey;
33+
using executorch::runtime::OptionKey;
34+
using executorch::runtime::register_backend;
35+
using executorch::runtime::Result;
36+
using executorch::runtime::Span;
37+
using executorch::runtime::StrKey;
38+
39+
// Mock backend for testing
40+
class StubBackend : public BackendInterface {
41+
public:
42+
~StubBackend() override = default;
43+
44+
bool is_available() const override {
45+
return true;
46+
}
47+
48+
Result<DelegateHandle*> init(
49+
BackendInitContext& context,
50+
FreeableBuffer* processed,
51+
ArrayRef<CompileSpec> compile_specs) const override {
52+
return nullptr;
53+
}
54+
55+
Error execute(
56+
BackendExecutionContext& context,
57+
DelegateHandle* handle,
58+
EValue** args) const override {
59+
return Error::Ok;
60+
}
61+
62+
Error get_option(
63+
BackendUpdateContext& context,
64+
executorch::runtime::Span<executorch::runtime::BackendOption>&
65+
backend_options) override {
66+
// For testing purposes, just record that get_option was called
67+
// and verify the input parameters
68+
get_option_called = true;
69+
get_option_call_count++;
70+
last_get_option_size = backend_options.size();
71+
72+
// Verify that the expected option key is present and modify the value
73+
for (size_t i = 0; i < backend_options.size(); ++i) {
74+
if (strcmp(backend_options[i].key, "NumberOfThreads") == 0) {
75+
// Set the value to what was stored by set_option
76+
backend_options[i].value = last_num_threads;
77+
found_expected_key = true;
78+
break;
79+
}
80+
}
81+
82+
return Error::Ok;
83+
}
84+
85+
Error set_option(
86+
BackendUpdateContext& context,
87+
const Span<executorch::runtime::BackendOption>& backend_options) override {
88+
// Store the options for verification
89+
last_options_size = backend_options.size();
90+
if (backend_options.size() > 0) {
91+
for (const auto& option : backend_options) {
92+
if (strcmp(option.key, "NumberOfThreads") == 0) {
93+
if (auto* val = std::get_if<int>(&option.value)) {
94+
last_num_threads = *val;
95+
}
96+
}
97+
}
98+
}
99+
return Error::Ok;
100+
}
101+
102+
// Mutable for testing verification
103+
size_t last_options_size = 0;
104+
int last_num_threads = 0;
105+
bool get_option_called = false;
106+
int get_option_call_count = 0;
107+
size_t last_get_option_size = 0;
108+
bool found_expected_key = false;
109+
};
110+
111+
class BackendUpdateTest : public ::testing::Test {
112+
protected:
113+
void SetUp() override {
114+
// Since these tests cause ET_LOG to be called, the PAL must be initialized
115+
// first.
116+
executorch::runtime::runtime_init();
117+
118+
// Register the stub backend
119+
stub_backend = std::make_unique<StubBackend>();
120+
Backend backend_config{"StubBackend", stub_backend.get()};
121+
auto register_result = register_backend(backend_config);
122+
ASSERT_EQ(register_result, Error::Ok);
123+
}
124+
125+
std::unique_ptr<StubBackend> stub_backend;
126+
};
127+
128+
// Test basic string functionality
129+
TEST_F(BackendUpdateTest, TestSetOption) {
130+
BackendOptionsMap<3> map;
131+
BackendOptions<1> backend_options;
132+
int new_num_threads = 4;
133+
backend_options.set_option(IntKey("NumberOfThreads"), new_num_threads);
134+
map.add("StubBackend", backend_options.view());
135+
136+
auto status = set_option(map.entries());
137+
ASSERT_EQ(status, Error::Ok);
138+
139+
// Verify the map contains the expected data
140+
ASSERT_EQ(map.size(), 1);
141+
auto options = map.get("StubBackend");
142+
ASSERT_EQ(options.size(), 1);
143+
144+
// Verify that the backend actually received the options
145+
ASSERT_EQ(stub_backend->last_options_size, 1);
146+
ASSERT_EQ(stub_backend->last_num_threads, new_num_threads);
147+
}
148+
149+
// Test get_option functionality
150+
TEST_F(BackendUpdateTest, TestGetOption) {
151+
// First, set some options in the backend
152+
BackendOptionsMap<3> set_map;
153+
BackendOptions<1> set_backend_options;
154+
int expected_num_threads = 8;
155+
set_backend_options.set_option(
156+
IntKey("NumberOfThreads"), expected_num_threads);
157+
set_map.add("StubBackend", set_backend_options.view());
158+
159+
auto set_status = set_option(set_map.entries());
160+
ASSERT_EQ(set_status, Error::Ok);
161+
ASSERT_EQ(stub_backend->last_num_threads, expected_num_threads);
162+
163+
// Reset get_option tracking variables
164+
stub_backend->get_option_called = false;
165+
stub_backend->get_option_call_count = 0;
166+
stub_backend->found_expected_key = false;
167+
168+
// Now create a map with options for get_option to process
169+
BackendOptionsMap<3> get_map;
170+
BackendOptions<1> get_backend_options;
171+
get_backend_options.set_option(IntKey("NumberOfThreads"), 0);
172+
get_map.add("StubBackend", get_backend_options.view());
173+
174+
// Call get_option to test the API
175+
auto get_status = get_option(get_map.entries());
176+
ASSERT_EQ(get_status, Error::Ok);
177+
178+
ASSERT_TRUE(
179+
std::get<int>(get_map.entries()[0].options[0].value) ==
180+
expected_num_threads);
181+
182+
// // Verify that the backend's get_option method was called correctly
183+
// ASSERT_TRUE(stub_backend->get_option_called);
184+
// ASSERT_EQ(stub_backend->get_option_call_count, 1);
185+
// ASSERT_EQ(stub_backend->last_get_option_size, 1);
186+
// ASSERT_TRUE(stub_backend->found_expected_key);
187+
}

runtime/backend/test/targets.bzl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,16 @@ def define_common_targets():
2424
],
2525
)
2626

27+
runtime.cxx_test(
28+
name = "backend_update_test",
29+
srcs = ["backend_update_test.cpp"],
30+
deps = [
31+
"//executorch/runtime/core:core",
32+
"//executorch/runtime/backend:backend_options_map",
33+
"//executorch/runtime/backend:backend_update",
34+
],
35+
)
36+
2737
runtime.cxx_test(
2838
name = "backend_interface_update_test",
2939
srcs = ["backend_interface_update_test.cpp"],

0 commit comments

Comments
 (0)