Skip to content

Add allocate_temp method to KernelRuntimeContext #3209

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion runtime/core/memory_allocator.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class MemoryAllocator {
/**
* Allocates `size` bytes of memory.
*
* @param[in] size Number of memory chunks to allocate.
* @param[in] size Number of bytes to allocate.
* @param[in] alignment Minimum alignment for the returned pointer. Must be a
* power of 2.
*
Expand Down
9 changes: 6 additions & 3 deletions runtime/executor/method.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1013,11 +1013,14 @@ Error Method::execute_instruction() {
EXECUTORCH_SCOPE_PROF("OPERATOR_CALL");
internal::EventTracerProfileScope event_tracer_scope =
internal::EventTracerProfileScope(event_tracer_, "OPERATOR_CALL");
// TODO(T147221312): Also expose the temp allocator and tensor resizer
// via the context.
KernelRuntimeContext context(event_tracer_);
// TODO(T147221312): Also expose tensor resizer via the context.
// The temp_allocator may be null; in that case any call to allocate_temp
// on this context returns an error instead of memory.
KernelRuntimeContext context(
event_tracer_, memory_manager_->temp_allocator());
auto args = chain.argument_lists_[step_state_.instr_idx];
chain.kernels_[step_state_.instr_idx](context, args.data());
// We reset the temp_allocator after the switch statement
err = context.failure_state();
if (err != Error::Ok) {
// We know that instr_args_as_KernelCall is non-null because it was
Expand Down
46 changes: 42 additions & 4 deletions runtime/kernel/kernel_runtime_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/event_tracer_hooks.h>
#include <executorch/runtime/core/memory_allocator.h>
#include <executorch/runtime/core/result.h>
#include <executorch/runtime/platform/compiler.h>

namespace torch {
Expand All @@ -24,10 +26,21 @@ namespace executor {
class KernelRuntimeContext {
public:
/**
* Construct a new kernel runtime context along with an optional event tracer.
* Construct a new kernel runtime context.
*
* KernelRuntimeContext does not take ownership
* of these pointers, so they must outlive the context instance.
*
* @param[in] event_tracer The optional EventTracer to use for
* profiling/debugging
* @param[in] temp_allocator The optional MemoryAllocator used to allocate
* temporary memory for the kernel. If not provided, an error will be
* returned when calling allocate_temp.
*/
KernelRuntimeContext(EventTracer* event_tracer = nullptr)
: event_tracer_(event_tracer) {}
// Stores the given pointers as-is; both default to nullptr when omitted.
KernelRuntimeContext(
EventTracer* event_tracer = nullptr,
MemoryAllocator* temp_allocator = nullptr)
: event_tracer_(event_tracer), temp_allocator_(temp_allocator) {}
/**
* Tells the runtime that the kernel call has failed. Prefer this over
* ET_CHECK_*(), which fatally panics the process/system.
Expand Down Expand Up @@ -60,12 +73,37 @@ class KernelRuntimeContext {
return event_tracer_;
}

// TODO(T147221312): Add a way to allocate temporary memory.
/**
 * Hands out a block of temporary memory from the temp allocator this context
 * was constructed with. The memory will be freed when the kernel returns.
 *
 * @param[in] size Number of bytes to allocate.
 * @param[in] alignment Minimum alignment for the returned pointer. Must be a
 *     power of 2.
 *
 * @returns A result object holding the pointer to the allocated memory on
 *     success. Holds Error::NotFound when the context has no temp allocator,
 *     or Error::MemoryAllocationFailed when the allocator returns nullptr.
 */
Result<void*> allocate_temp(
    size_t size,
    size_t alignment = MemoryAllocator::kDefaultAlignment) {
  // A context built without a temp allocator has nothing to allocate from.
  ET_CHECK_OR_RETURN_ERROR(
      temp_allocator_ != nullptr, NotFound, "No temp allocator provided");

  void* buffer = temp_allocator_->allocate(size, alignment);

  // The allocator signals failure by returning nullptr.
  ET_CHECK_OR_RETURN_ERROR(
      buffer != nullptr,
      MemoryAllocationFailed,
      "Failed to allocate temp memory. Bytes requested: %zu",
      size);
  return buffer;
}

// TODO(T147221312): Add a way to resize a tensor.

private:
EventTracer* event_tracer_ = nullptr;
MemoryAllocator* temp_allocator_ = nullptr;
Error failure_state_ = Error::Ok;
};

Expand Down
1 change: 1 addition & 0 deletions runtime/kernel/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def define_common_targets():
exported_deps = [
"//executorch/runtime/core:core",
"//executorch/runtime/platform:platform",
"//executorch/runtime/core:memory_allocator",
"//executorch/runtime/core:event_tracer" + aten_suffix,
# TODO(T147221312): This will eventually depend on exec_aten
# once KernelRuntimeContext support tensor resizing, which is
Expand Down
53 changes: 53 additions & 0 deletions runtime/kernel/test/kernel_runtime_context_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
using namespace ::testing;
using torch::executor::Error;
using torch::executor::KernelRuntimeContext;
using torch::executor::MemoryAllocator;
using torch::executor::Result;

class KernelRuntimeContextTest : public ::testing::Test {
public:
Expand All @@ -23,6 +25,17 @@ class KernelRuntimeContextTest : public ::testing::Test {
}
};

/**
 * MemoryAllocator subclass for tests: remembers the alignment argument of the
 * most recent allocate() call, then defers to the base-class implementation.
 */
class TestMemoryAllocator : public MemoryAllocator {
 public:
  TestMemoryAllocator(uint32_t size, uint8_t* base_address)
      : MemoryAllocator(size, base_address) {}

  void* allocate(size_t size, size_t alignment) override {
    // Record the alignment before forwarding so tests can inspect it.
    last_seen_alignment = alignment;
    return MemoryAllocator::allocate(size, alignment);
  }

  // Alignment passed to the latest allocate() call; 0 until the first call.
  size_t last_seen_alignment = 0;
};

TEST_F(KernelRuntimeContextTest, FailureStateDefaultsToOk) {
KernelRuntimeContext context;

Expand All @@ -47,3 +60,43 @@ TEST_F(KernelRuntimeContextTest, FailureStateReflectsFailure) {
context.fail(Error::Ok);
EXPECT_EQ(context.failure_state(), Error::Ok);
}

// A default-constructed context has no temp allocator, so allocate_temp must
// report NotFound.
TEST_F(KernelRuntimeContextTest, FailureNoMemoryAllocatorProvided) {
  KernelRuntimeContext context;

  Result<void*> result = context.allocate_temp(4);

  EXPECT_EQ(result.error(), Error::NotFound);
}

// Requesting as many bytes as the pool holds must succeed.
TEST_F(KernelRuntimeContextTest, SuccessfulMemoryAllocation) {
  constexpr size_t temp_memory_allocator_pool_size = 4;
  auto temp_memory_allocator_pool =
      std::make_unique<uint8_t[]>(temp_memory_allocator_pool_size);
  MemoryAllocator temp_allocator(
      temp_memory_allocator_pool_size, temp_memory_allocator_pool.get());
  KernelRuntimeContext context(nullptr, &temp_allocator);

  // Tie the request size to the pool size instead of repeating the literal.
  Result<void*> allocated_memory =
      context.allocate_temp(temp_memory_allocator_pool_size);

  // EXPECT_TRUE is the idiomatic boolean check (vs. EXPECT_EQ(..., true)).
  EXPECT_TRUE(allocated_memory.ok());
}

// Asking for more bytes than the pool holds must surface
// MemoryAllocationFailed through the returned Result.
TEST_F(KernelRuntimeContextTest, FailureMemoryAllocationInsufficientSpace) {
  constexpr size_t pool_size = 4;
  constexpr size_t oversized_request = 8;

  auto pool = std::make_unique<uint8_t[]>(pool_size);
  MemoryAllocator allocator(pool_size, pool.get());
  KernelRuntimeContext context(nullptr, &allocator);

  Result<void*> result = context.allocate_temp(oversized_request);

  EXPECT_EQ(result.error(), Error::MemoryAllocationFailed);
}

// The alignment given to allocate_temp must reach the underlying allocator
// unchanged.
TEST_F(KernelRuntimeContextTest, MemoryAllocatorAlignmentPassed) {
  constexpr size_t temp_memory_allocator_pool_size = 4;
  // Typed as size_t so the comparison below is unsigned-vs-unsigned; comparing
  // against a bare `2` mixes size_t with int and trips -Wsign-compare.
  constexpr size_t requested_alignment = 2;

  auto temp_memory_allocator_pool =
      std::make_unique<uint8_t[]>(temp_memory_allocator_pool_size);
  TestMemoryAllocator temp_allocator(
      temp_memory_allocator_pool_size, temp_memory_allocator_pool.get());
  KernelRuntimeContext context(nullptr, &temp_allocator);

  Result<void*> allocated_memory =
      context.allocate_temp(4, requested_alignment);

  EXPECT_TRUE(allocated_memory.ok());
  EXPECT_EQ(temp_allocator.last_seen_alignment, requested_alignment);
}