Skip to content

Commit 49ccb70

Browse files
dbortfacebook-github-bot
authored andcommitted
Hook up KernelRuntimeContext.fail()
Summary: This gives kernels a way to fail non-fatally. We still plan to add more features to `KernelRuntimeContext`, but they're lower priority right now. Reviewed By: JacobSzwejbka Differential Revision: D48198665 fbshipit-source-id: 59e22a568a658bab1358835e577840deda511465
1 parent 6e0e0cc commit 49ccb70

File tree

7 files changed

+357
-17
lines changed

7 files changed

+357
-17
lines changed

runtime/executor/method.cpp

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -786,10 +786,24 @@ Error Method::execute_instruction() {
786786
switch (instruction->instr_args_type()) {
787787
case executorch_flatbuffer::InstructionArguments::KernelCall: {
788788
EXECUTORCH_SCOPE_PROF("OPERATOR_CALL");
789-
KernelRuntimeContext context{};
789+
// TODO(T147221312): Also expose the temp allocator and tensor resizer
790+
// via the context.
791+
KernelRuntimeContext context;
790792
chain.kernels_[step_state_.instr_idx](
791793
context, chain.argument_lists_[step_state_.instr_idx].data());
792-
// TODO(T135464333): inspect runtime context for error state
794+
Error err = context.failure_state();
795+
if (err != Error::Ok) {
796+
ET_LOG(
797+
Error,
798+
"KernelCall failed at instruction %zu:%zu: 0x%x",
799+
step_state_.chain_idx,
800+
step_state_.instr_idx,
801+
(unsigned int)err);
802+
// TODO(T153804650): Consider logging the EValues to help with
803+
// debugging. This is a failure path, and it doesn't matter if it's a
804+
// little slow. Do the same for DelegateCall errors.
805+
return err;
806+
}
793807
} break;
794808
case executorch_flatbuffer::InstructionArguments::DelegateCall: {
795809
EXECUTORCH_SCOPE_PROF("DELEGATE_CALL");
Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,219 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <cctype>
10+
#include <filesystem>
11+
12+
#include <cstring>
13+
#include <memory>
14+
15+
#include <executorch/extension/data_loader/file_data_loader.h>
16+
#include <executorch/runtime/core/error.h>
17+
#include <executorch/runtime/core/result.h>
18+
#include <executorch/runtime/executor/method.h>
19+
#include <executorch/runtime/executor/program.h>
20+
#include <executorch/runtime/executor/test/managed_memory_manager.h>
21+
#include <executorch/runtime/kernel/kernel_runtime_context.h>
22+
#include <executorch/runtime/kernel/operator_registry.h>
23+
#include <executorch/runtime/platform/compiler.h>
24+
#include <executorch/runtime/platform/runtime.h>
25+
#include <executorch/util/util.h>
26+
27+
#include <gtest/gtest.h>
28+
29+
using namespace ::testing;
30+
using torch::executor::ArrayRef;
31+
using torch::executor::Error;
32+
using torch::executor::EValue;
33+
using torch::executor::FreeableBuffer;
34+
using torch::executor::Kernel;
35+
using torch::executor::KernelKey;
36+
using torch::executor::KernelRuntimeContext;
37+
using torch::executor::Method;
38+
using torch::executor::Program;
39+
using torch::executor::Result;
40+
using torch::executor::testing::ManagedMemoryManager;
41+
using torch::executor::util::FileDataLoader;
42+
43+
constexpr size_t kDefaultNonConstMemBytes = 32 * 1024U;
44+
constexpr size_t kDefaultRuntimeMemBytes = 32 * 1024U;
45+
46+
/**
47+
* Used to control and observe the behavior of a kernel.
48+
*/
49+
struct KernelControl {
50+
public:
51+
// The number of times the kernel has been called.
52+
int call_count = 0;
53+
54+
// If true, the kernel should call `context.fail(error_to_set)`. If false,
55+
// the kernel should not call `context.fail()`.
56+
bool call_context_fail = true;
57+
58+
// The error value that the kernel should pass to `context.fail()` before
59+
// returning.
60+
Error fail_value = Error::Ok;
61+
62+
void reset() {
63+
call_count = 0;
64+
call_context_fail = false;
65+
fail_value = Error::Ok;
66+
}
67+
68+
/**
69+
* Registers a kernel that uses the singleton instance to record and control
70+
* its behavior.
71+
*/
72+
static void register_singleton() {
73+
if (registered_) {
74+
return;
75+
}
76+
77+
// This test helper installs itself as aten::add.out:
78+
//
79+
// add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) ->
80+
// Tensor(a!)
81+
//
82+
// The arguments are: `self, other, out, out` (we repeat the out argument in
83+
// the program). And since we traced using randn(2, 2), all the args are
84+
// Float with dim order (0, 1)
85+
86+
// Construct a kernel key with the following meta:
87+
// exec_aten::DimOrderType contiguous[] = {0, 1};
88+
// TensorMeta float_contiguous[] = {
89+
// TensorMeta(ScalarType::Float, contiguous), // self
90+
// TensorMeta(ScalarType::Float, contiguous), // other
91+
// TensorMeta(ScalarType::Float, contiguous), // out
92+
// TensorMeta(ScalarType::Float, contiguous)}; // out (repeated)
93+
KernelKey key = torch::executor::KernelKey(
94+
"v0/\x06;\x00\x01|\x06;\x00\x01|\x06;\x00\x01|\x06;\x00\x01\xff");
95+
Kernel kernel = torch::executor::Kernel(
96+
"aten::add.out", key, KernelControl::kernel_hook);
97+
Error err = torch::executor::register_kernels({kernel});
98+
EXPECT_EQ(err, Error::Ok);
99+
100+
registered_ = true;
101+
}
102+
103+
static KernelControl* singleton() {
104+
return &singleton_;
105+
}
106+
107+
private:
108+
/**
109+
* An OpFunction-compatible function that uses the singleton KernelControl
110+
* to record and determine its behavior.
111+
*/
112+
static void kernel_hook(
113+
KernelRuntimeContext& context,
114+
__ET_UNUSED EValue** args) {
115+
auto* control = KernelControl::singleton();
116+
control->call_count++;
117+
if (control->call_context_fail) {
118+
context.fail(control->fail_value);
119+
}
120+
}
121+
122+
static bool registered_;
123+
static KernelControl singleton_;
124+
};
125+
126+
bool KernelControl::registered_ = false;
127+
KernelControl KernelControl::singleton_;
128+
129+
class KernelIntegrationTest : public ::testing::Test {
130+
protected:
131+
void SetUp() override {
132+
// Register the controllable kernel hook.
133+
KernelControl::register_singleton();
134+
// Ensure that its state is clear.
135+
KernelControl::singleton()->reset();
136+
// Provide the singleton to the tests.
137+
control_ = KernelControl::singleton();
138+
139+
// Create a loader for the serialized ModuleAdd program.
140+
const char* path = std::getenv("ET_MODULE_ADD_PATH");
141+
Result<FileDataLoader> loader = FileDataLoader::From(path);
142+
ASSERT_EQ(loader.error(), Error::Ok);
143+
loader_ = std::make_unique<FileDataLoader>(std::move(loader.get()));
144+
145+
// Use it to load the program.
146+
Result<Program> program = Program::Load(
147+
loader_.get(), Program::Verification::InternalConsistency);
148+
ASSERT_EQ(program.error(), Error::Ok);
149+
program_ = std::make_unique<Program>(std::move(program.get()));
150+
151+
// Load the forward method.
152+
mmm_ = std::make_unique<ManagedMemoryManager>(
153+
kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes);
154+
Result<Method> method = program_->load_method("forward", &mmm_->get());
155+
ASSERT_EQ(method.error(), Error::Ok);
156+
method_ = std::make_unique<Method>(std::move(method.get()));
157+
158+
// Set up its inputs.
159+
inputs_ = torch::executor::util::PrepareInputTensors(*method_);
160+
}
161+
162+
void TearDown() override {
163+
torch::executor::util::FreeInputs(inputs_);
164+
inputs_ = {};
165+
}
166+
167+
private:
168+
// Must outlive program_
169+
std::unique_ptr<FileDataLoader> loader_;
170+
171+
// Must outlive method_
172+
std::unique_ptr<Program> program_;
173+
std::unique_ptr<ManagedMemoryManager> mmm_;
174+
ArrayRef<void*> inputs_;
175+
176+
protected:
177+
// An executable method that will call the kernel associated with control_.
178+
// Its inputs will have been allocated and initialized.
179+
std::unique_ptr<Method> method_;
180+
181+
// The KernelControl associated with method_.
182+
KernelControl* control_;
183+
};
184+
185+
TEST_F(KernelIntegrationTest, KernelHookIsCalled) {
186+
// Demonstrate that the kernel hook is called in the default state.
187+
EXPECT_EQ(control_->call_count, 0);
188+
Error err = method_->execute();
189+
EXPECT_EQ(err, Error::Ok);
190+
EXPECT_EQ(control_->call_count, 1);
191+
192+
// Calling it again bumps the count.
193+
err = method_->execute();
194+
EXPECT_EQ(err, Error::Ok);
195+
EXPECT_EQ(control_->call_count, 2);
196+
}
197+
198+
TEST_F(KernelIntegrationTest, FailurePropagates) {
199+
// Tell the kernel to fail.
200+
control_->call_context_fail = true;
201+
202+
// We should see the error from the kernel.
203+
control_->fail_value = Error::InvalidArgument;
204+
Error err = method_->execute();
205+
EXPECT_EQ(err, Error::InvalidArgument);
206+
EXPECT_EQ(control_->call_count, 1);
207+
208+
// Have it fail with a different error to show that it's not a coincidence.
209+
control_->fail_value = Error::MemoryAllocationFailed;
210+
err = method_->execute();
211+
EXPECT_EQ(err, Error::MemoryAllocationFailed);
212+
EXPECT_EQ(control_->call_count, 2);
213+
214+
// Returning an Ok does not cause the execution to fail.
215+
control_->fail_value = Error::Ok;
216+
err = method_->execute();
217+
EXPECT_EQ(err, Error::Ok);
218+
EXPECT_EQ(control_->call_count, 3);
219+
}

runtime/executor/test/targets.bzl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,24 @@ def define_common_targets(is_fbcode = False):
168168
env = modules_env,
169169
)
170170

171+
runtime.cxx_test(
172+
name = "kernel_integration_test",
173+
srcs = [
174+
"kernel_integration_test.cpp",
175+
],
176+
deps = [
177+
":managed_memory_manager",
178+
"//executorch/extension/data_loader:file_data_loader",
179+
"//executorch/runtime/core:core",
180+
"//executorch/runtime/executor:program",
181+
"//executorch/runtime/kernel:kernel_runtime_context",
182+
"//executorch/runtime/kernel:operator_registry",
183+
"//executorch/runtime/platform:platform",
184+
"//executorch/util:util",
185+
],
186+
env = modules_env,
187+
)
188+
171189
runtime.cxx_test(
172190
name = "backend_integration_test",
173191
srcs = [

runtime/kernel/kernel_runtime_context.h

Lines changed: 36 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,24 +8,48 @@
88

99
#pragma once
1010

11+
#include <executorch/runtime/core/error.h>
12+
#include <executorch/runtime/platform/compiler.h>
13+
1114
namespace torch {
1215
namespace executor {
1316

1417
/**
15-
* Bucket type abstraction that contains many elements of runtime state that
16-
* a kernel author may want available, but would otherwise be unable to access.
17-
*
18-
* Forwarded along to all operators when running in lean mode. NOTE: Will not be
19-
* forwarded to operators if running in ATen mode as those operators do not
20-
* expect to receive a KernelRuntimeContext and would not use it.
18+
* Runtime state and functionality for kernel implementations.
2119
*
22-
* This includes things like setting an error state, a scratch allocator for
23-
* operators that need more then constant space, and a TensorResizer for dynamic
24-
* shape tensors allowing programs to be more flexible with Tensor shape.
25-
*
26-
* TODO(T147221312): Define this interface
20+
* NOTE: Will not be passed to operators if running in ATen mode as those
21+
* operators do not expect to receive a KernelRuntimeContext argument.
2722
*/
28-
class KernelRuntimeContext {};
23+
class KernelRuntimeContext {
24+
public:
25+
/**
26+
* Tells the runtime that the kernel call has failed. Prefer this over
27+
* ET_CHECK_*(), which fatally panics the process/system.
28+
*
29+
* If this is not called, the runtime will treat the kernel call as
30+
* successful.
31+
*
32+
* This unusual error-propagation path is required because kernel signatures
33+
* do not have a natural way to return errors directly. They are generally
34+
* compatible with core PyTorch ATen kernel signatures, which use exceptions
35+
* to report errors. But, ExecuTorch does not use exceptions.
36+
*/
37+
void fail(Error error) {
38+
failure_state_ = error;
39+
}
40+
41+
/// Returns the current failure state.
42+
__ET_NODISCARD Error failure_state() const {
43+
return failure_state_;
44+
}
45+
46+
// TODO(T147221312): Add a way to allocate temporary memory.
47+
48+
// TODO(T147221312): Add a way to resize a tensor.
49+
50+
private:
51+
Error failure_state_ = Error::Ok;
52+
};
2953

3054
} // namespace executor
3155
} // namespace torch

runtime/kernel/targets.bzl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,15 +30,17 @@ def define_common_targets():
3030
"kernel_runtime_context.h",
3131
],
3232
visibility = [
33-
"//executorch/kernels/prim_ops/...", # Contains kernels
34-
"//executorch/runtime/kernel/...",
3533
"//executorch/kernels/...",
3634
"//executorch/runtime/executor/...",
35+
"//executorch/runtime/kernel/...",
3736
"@EXECUTORCH_CLIENTS",
3837
],
3938
exported_deps = [
4039
"//executorch/runtime/core:core",
41-
"//executorch/runtime/core/exec_aten:lib" + aten_suffix,
40+
"//executorch/runtime/platform:platform",
41+
# TODO(T147221312): This will eventually depend on exec_aten
42+
# once KernelRuntimeContext support tensor resizing, which is
43+
# why this target supports aten mode.
4244
],
4345
)
4446

0 commit comments

Comments
 (0)