|
| 1 | +/* |
| 2 | + * Copyright (c) Meta Platforms, Inc. and affiliates. |
| 3 | + * All rights reserved. |
| 4 | + * |
| 5 | + * This source code is licensed under the BSD-style license found in the |
| 6 | + * LICENSE file in the root directory of this source tree. |
| 7 | + */ |
| 8 | + |
| 9 | +#include <cctype> |
| 10 | +#include <filesystem> |
| 11 | + |
| 12 | +#include <cstring> |
| 13 | +#include <memory> |
| 14 | + |
| 15 | +#include <executorch/extension/data_loader/file_data_loader.h> |
| 16 | +#include <executorch/runtime/core/error.h> |
| 17 | +#include <executorch/runtime/core/result.h> |
| 18 | +#include <executorch/runtime/executor/method.h> |
| 19 | +#include <executorch/runtime/executor/program.h> |
| 20 | +#include <executorch/runtime/executor/test/managed_memory_manager.h> |
| 21 | +#include <executorch/runtime/kernel/kernel_runtime_context.h> |
| 22 | +#include <executorch/runtime/kernel/operator_registry.h> |
| 23 | +#include <executorch/runtime/platform/compiler.h> |
| 24 | +#include <executorch/runtime/platform/runtime.h> |
| 25 | +#include <executorch/util/util.h> |
| 26 | + |
| 27 | +#include <gtest/gtest.h> |
| 28 | + |
| 29 | +using namespace ::testing; |
| 30 | +using torch::executor::ArrayRef; |
| 31 | +using torch::executor::Error; |
| 32 | +using torch::executor::EValue; |
| 33 | +using torch::executor::FreeableBuffer; |
| 34 | +using torch::executor::Kernel; |
| 35 | +using torch::executor::KernelKey; |
| 36 | +using torch::executor::KernelRuntimeContext; |
| 37 | +using torch::executor::Method; |
| 38 | +using torch::executor::Program; |
| 39 | +using torch::executor::Result; |
| 40 | +using torch::executor::testing::ManagedMemoryManager; |
| 41 | +using torch::executor::util::FileDataLoader; |
| 42 | + |
| 43 | +constexpr size_t kDefaultNonConstMemBytes = 32 * 1024U; |
| 44 | +constexpr size_t kDefaultRuntimeMemBytes = 32 * 1024U; |
| 45 | + |
| 46 | +/** |
| 47 | + * Used to control and observe the behavior of a kernel. |
| 48 | + */ |
| 49 | +struct KernelControl { |
| 50 | + public: |
| 51 | + // The number of times the kernel has been called. |
| 52 | + int call_count = 0; |
| 53 | + |
| 54 | + // If true, the kernel should call `context.fail(error_to_set)`. If false, |
| 55 | + // the kernel should not call `context.fail()`. |
| 56 | + bool call_context_fail = true; |
| 57 | + |
| 58 | + // The error value that the kernel should pass to `context.fail()` before |
| 59 | + // returning. |
| 60 | + Error fail_value = Error::Ok; |
| 61 | + |
| 62 | + void reset() { |
| 63 | + call_count = 0; |
| 64 | + call_context_fail = false; |
| 65 | + fail_value = Error::Ok; |
| 66 | + } |
| 67 | + |
| 68 | + /** |
| 69 | + * Registers a kernel that uses the singleton instance to record and control |
| 70 | + * its behavior. |
| 71 | + */ |
| 72 | + static void register_singleton() { |
| 73 | + if (registered_) { |
| 74 | + return; |
| 75 | + } |
| 76 | + |
| 77 | + // This test helper installs itself as aten::add.out: |
| 78 | + // |
| 79 | + // add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> |
| 80 | + // Tensor(a!) |
| 81 | + // |
| 82 | + // The arguments are: `self, other, out, out` (we repeat the out argument in |
| 83 | + // the program). And since we traced using randn(2, 2), all the args are |
| 84 | + // Float with dim order (0, 1) |
| 85 | + |
| 86 | + // Construct a kernel key with the following meta: |
| 87 | + // exec_aten::DimOrderType contiguous[] = {0, 1}; |
| 88 | + // TensorMeta float_contiguous[] = { |
| 89 | + // TensorMeta(ScalarType::Float, contiguous), // self |
| 90 | + // TensorMeta(ScalarType::Float, contiguous), // other |
| 91 | + // TensorMeta(ScalarType::Float, contiguous), // out |
| 92 | + // TensorMeta(ScalarType::Float, contiguous)}; // out (repeated) |
| 93 | + KernelKey key = torch::executor::KernelKey( |
| 94 | + "v0/\x06;\x00\x01|\x06;\x00\x01|\x06;\x00\x01|\x06;\x00\x01\xff"); |
| 95 | + Kernel kernel = torch::executor::Kernel( |
| 96 | + "aten::add.out", key, KernelControl::kernel_hook); |
| 97 | + Error err = torch::executor::register_kernels({kernel}); |
| 98 | + EXPECT_EQ(err, Error::Ok); |
| 99 | + |
| 100 | + registered_ = true; |
| 101 | + } |
| 102 | + |
| 103 | + static KernelControl* singleton() { |
| 104 | + return &singleton_; |
| 105 | + } |
| 106 | + |
| 107 | + private: |
| 108 | + /** |
| 109 | + * An OpFunction-compatible function that uses the singleton KernelControl |
| 110 | + * to record and determine its behavior. |
| 111 | + */ |
| 112 | + static void kernel_hook( |
| 113 | + KernelRuntimeContext& context, |
| 114 | + __ET_UNUSED EValue** args) { |
| 115 | + auto* control = KernelControl::singleton(); |
| 116 | + control->call_count++; |
| 117 | + if (control->call_context_fail) { |
| 118 | + context.fail(control->fail_value); |
| 119 | + } |
| 120 | + } |
| 121 | + |
| 122 | + static bool registered_; |
| 123 | + static KernelControl singleton_; |
| 124 | +}; |
| 125 | + |
| 126 | +bool KernelControl::registered_ = false; |
| 127 | +KernelControl KernelControl::singleton_; |
| 128 | + |
| 129 | +class KernelIntegrationTest : public ::testing::Test { |
| 130 | + protected: |
| 131 | + void SetUp() override { |
| 132 | + // Register the controllable kernel hook. |
| 133 | + KernelControl::register_singleton(); |
| 134 | + // Ensure that its state is clear. |
| 135 | + KernelControl::singleton()->reset(); |
| 136 | + // Provide the singleton to the tests. |
| 137 | + control_ = KernelControl::singleton(); |
| 138 | + |
| 139 | + // Create a loader for the serialized ModuleAdd program. |
| 140 | + const char* path = std::getenv("ET_MODULE_ADD_PATH"); |
| 141 | + Result<FileDataLoader> loader = FileDataLoader::From(path); |
| 142 | + ASSERT_EQ(loader.error(), Error::Ok); |
| 143 | + loader_ = std::make_unique<FileDataLoader>(std::move(loader.get())); |
| 144 | + |
| 145 | + // Use it to load the program. |
| 146 | + Result<Program> program = Program::Load( |
| 147 | + loader_.get(), Program::Verification::InternalConsistency); |
| 148 | + ASSERT_EQ(program.error(), Error::Ok); |
| 149 | + program_ = std::make_unique<Program>(std::move(program.get())); |
| 150 | + |
| 151 | + // Load the forward method. |
| 152 | + mmm_ = std::make_unique<ManagedMemoryManager>( |
| 153 | + kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes); |
| 154 | + Result<Method> method = program_->load_method("forward", &mmm_->get()); |
| 155 | + ASSERT_EQ(method.error(), Error::Ok); |
| 156 | + method_ = std::make_unique<Method>(std::move(method.get())); |
| 157 | + |
| 158 | + // Set up its inputs. |
| 159 | + inputs_ = torch::executor::util::PrepareInputTensors(*method_); |
| 160 | + } |
| 161 | + |
| 162 | + void TearDown() override { |
| 163 | + torch::executor::util::FreeInputs(inputs_); |
| 164 | + inputs_ = {}; |
| 165 | + } |
| 166 | + |
| 167 | + private: |
| 168 | + // Must outlive program_ |
| 169 | + std::unique_ptr<FileDataLoader> loader_; |
| 170 | + |
| 171 | + // Must outlive method_ |
| 172 | + std::unique_ptr<Program> program_; |
| 173 | + std::unique_ptr<ManagedMemoryManager> mmm_; |
| 174 | + ArrayRef<void*> inputs_; |
| 175 | + |
| 176 | + protected: |
| 177 | + // An executable method that will call the kernel associated with control_. |
| 178 | + // Its inputs will have been allocated and initialized. |
| 179 | + std::unique_ptr<Method> method_; |
| 180 | + |
| 181 | + // The KernelControl associated with method_. |
| 182 | + KernelControl* control_; |
| 183 | +}; |
| 184 | + |
| 185 | +TEST_F(KernelIntegrationTest, KernelHookIsCalled) { |
| 186 | + // Demonstrate that the kernel hook is called in the default state. |
| 187 | + EXPECT_EQ(control_->call_count, 0); |
| 188 | + Error err = method_->execute(); |
| 189 | + EXPECT_EQ(err, Error::Ok); |
| 190 | + EXPECT_EQ(control_->call_count, 1); |
| 191 | + |
| 192 | + // Calling it again bumps the count. |
| 193 | + err = method_->execute(); |
| 194 | + EXPECT_EQ(err, Error::Ok); |
| 195 | + EXPECT_EQ(control_->call_count, 2); |
| 196 | +} |
| 197 | + |
| 198 | +TEST_F(KernelIntegrationTest, FailurePropagates) { |
| 199 | + // Tell the kernel to fail. |
| 200 | + control_->call_context_fail = true; |
| 201 | + |
| 202 | + // We should see the error from the kernel. |
| 203 | + control_->fail_value = Error::InvalidArgument; |
| 204 | + Error err = method_->execute(); |
| 205 | + EXPECT_EQ(err, Error::InvalidArgument); |
| 206 | + EXPECT_EQ(control_->call_count, 1); |
| 207 | + |
| 208 | + // Have it fail with a different error to show that it's not a coincidence. |
| 209 | + control_->fail_value = Error::MemoryAllocationFailed; |
| 210 | + err = method_->execute(); |
| 211 | + EXPECT_EQ(err, Error::MemoryAllocationFailed); |
| 212 | + EXPECT_EQ(control_->call_count, 2); |
| 213 | + |
| 214 | + // Returning an Ok does not cause the execution to fail. |
| 215 | + control_->fail_value = Error::Ok; |
| 216 | + err = method_->execute(); |
| 217 | + EXPECT_EQ(err, Error::Ok); |
| 218 | + EXPECT_EQ(control_->call_count, 3); |
| 219 | +} |
0 commit comments