Skip to content

Commit 81dbd06

Browse files
David Linfacebook-github-bot
authored andcommitted
Add allocate_temp method to KernelRuntimeContext (#3209)
Summary: This adds an `allocate_temp` method to KernelRuntimeContext, and passes the temporary memory allocator from `execute_instruction`. The method returns a result that errors if the temporary `MemoryAllocator` was not provided or the memory could not be allocated. Reviewed By: dbort Differential Revision: D56421957
1 parent ee8c3a6 commit 81dbd06

File tree

5 files changed

+103
-8
lines changed

5 files changed

+103
-8
lines changed

runtime/core/memory_allocator.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ class MemoryAllocator {
6363
/**
6464
* Allocates `size` bytes of memory.
6565
*
66-
* @param[in] size Number of memory chunks to allocate.
66+
* @param[in] size Number of bytes to allocate.
6767
* @param[in] alignment Minimum alignment for the returned pointer. Must be a
6868
* power of 2.
6969
*

runtime/executor/method.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1013,11 +1013,14 @@ Error Method::execute_instruction() {
10131013
EXECUTORCH_SCOPE_PROF("OPERATOR_CALL");
10141014
internal::EventTracerProfileScope event_tracer_scope =
10151015
internal::EventTracerProfileScope(event_tracer_, "OPERATOR_CALL");
1016-
// TODO(T147221312): Also expose the temp allocator and tensor resizer
1017-
// via the context.
1018-
KernelRuntimeContext context(event_tracer_);
1016+
// TODO(T147221312): Also expose tensor resizer via the context.
1017+
// The temp_allocator passed can be null, but calling allocate_temp will
1018+
// fail
1019+
KernelRuntimeContext context(
1020+
event_tracer_, memory_manager_->temp_allocator());
10191021
auto args = chain.argument_lists_[step_state_.instr_idx];
10201022
chain.kernels_[step_state_.instr_idx](context, args.data());
1023+
// We reset the temp_allocator after the switch statement
10211024
err = context.failure_state();
10221025
if (err != Error::Ok) {
10231026
// We know that instr_args_as_KernelCall is non-null because it was

runtime/kernel/kernel_runtime_context.h

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010

1111
#include <executorch/runtime/core/error.h>
1212
#include <executorch/runtime/core/event_tracer_hooks.h>
13+
#include <executorch/runtime/core/memory_allocator.h>
14+
#include <executorch/runtime/core/result.h>
1315
#include <executorch/runtime/platform/compiler.h>
1416

1517
namespace torch {
@@ -24,10 +26,21 @@ namespace executor {
2426
class KernelRuntimeContext {
2527
public:
2628
/**
27-
* Construct a new kernel runtime context along with an optional event tracer.
29+
* Construct a new kernel runtime context.
30+
*
31+
* KernelRuntimeContext does not take ownership
32+
* of these pointers, so they must outlive the context instance.
33+
*
34+
* @param[in] event_tracer The optional EventTracer to use for
35+
* profiling/debugging
36+
* @param[in] temp_allocator The optional MemoryAllocator used to allocate
37+
* temporary memory for the kernel. If not provided, an error will be
38+
* returned when calling allocate_temp.
2839
*/
29-
KernelRuntimeContext(EventTracer* event_tracer = nullptr)
30-
: event_tracer_(event_tracer) {}
40+
KernelRuntimeContext(
41+
EventTracer* event_tracer = nullptr,
42+
MemoryAllocator* temp_allocator = nullptr)
43+
: event_tracer_(event_tracer), temp_allocator_(temp_allocator) {}
3144
/**
3245
* Tells the runtime that the kernel call has failed. Prefer this over
3346
* ET_CHECK_*(), which fatally panics the process/system.
@@ -60,12 +73,37 @@ class KernelRuntimeContext {
6073
return event_tracer_;
6174
}
6275

63-
// TODO(T147221312): Add a way to allocate temporary memory.
76+
/**
77+
* Allocates temporary memory that will be freed when the kernel returns. This
78+
* returns a pointer to the allocated memory or an error if the allocation
79+
* fails.
80+
*
81+
* @param[in] size Number of bytes to allocate.
82+
* @param[in] alignment Minimum alignment for the returned pointer. Must be a
83+
* power of 2.
84+
*
85+
* @returns A result object containing either a pointer to the allocated
86+
* memory or an error to indicate failure
87+
*/
88+
Result<void*> allocate_temp(
89+
size_t size,
90+
size_t alignment = MemoryAllocator::kDefaultAlignment) {
91+
ET_CHECK_OR_RETURN_ERROR(
92+
temp_allocator_ != nullptr, NotFound, "No temp allocator provided");
93+
void* temp_memory = temp_allocator_->allocate(size, alignment);
94+
ET_CHECK_OR_RETURN_ERROR(
95+
temp_memory != nullptr,
96+
MemoryAllocationFailed,
97+
"Failed to allocate temp memory. Bytes requested: %zu",
98+
size);
99+
return temp_memory;
100+
}
64101

65102
// TODO(T147221312): Add a way to resize a tensor.
66103

67104
private:
68105
EventTracer* event_tracer_ = nullptr;
106+
MemoryAllocator* temp_allocator_ = nullptr;
69107
Error failure_state_ = Error::Ok;
70108
};
71109

runtime/kernel/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ def define_common_targets():
5555
exported_deps = [
5656
"//executorch/runtime/core:core",
5757
"//executorch/runtime/platform:platform",
58+
"//executorch/runtime/core:memory_allocator",
5859
"//executorch/runtime/core:event_tracer" + aten_suffix,
5960
# TODO(T147221312): This will eventually depend on exec_aten
6061
# once KernelRuntimeContext support tensor resizing, which is

runtime/kernel/test/kernel_runtime_context_test.cpp

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
using namespace ::testing;
1616
using torch::executor::Error;
1717
using torch::executor::KernelRuntimeContext;
18+
using torch::executor::MemoryAllocator;
19+
using torch::executor::Result;
1820

1921
class KernelRuntimeContextTest : public ::testing::Test {
2022
public:
@@ -23,6 +25,17 @@ class KernelRuntimeContextTest : public ::testing::Test {
2325
}
2426
};
2527

28+
class TestMemoryAllocator : public MemoryAllocator {
29+
public:
30+
TestMemoryAllocator(uint32_t size, uint8_t* base_address)
31+
: MemoryAllocator(size, base_address), last_seen_alignment(0) {}
32+
void* allocate(size_t size, size_t alignment) {
33+
last_seen_alignment = alignment;
34+
return MemoryAllocator::allocate(size, alignment);
35+
}
36+
size_t last_seen_alignment;
37+
};
38+
2639
TEST_F(KernelRuntimeContextTest, FailureStateDefaultsToOk) {
2740
KernelRuntimeContext context;
2841

@@ -47,3 +60,43 @@ TEST_F(KernelRuntimeContextTest, FailureStateReflectsFailure) {
4760
context.fail(Error::Ok);
4861
EXPECT_EQ(context.failure_state(), Error::Ok);
4962
}
63+
64+
TEST_F(KernelRuntimeContextTest, FailureNoMemoryAllocatorProvided) {
65+
KernelRuntimeContext context;
66+
Result<void*> allocated_memory = context.allocate_temp(4);
67+
EXPECT_EQ(allocated_memory.error(), Error::NotFound);
68+
}
69+
70+
TEST_F(KernelRuntimeContextTest, SuccessfulMemoryAllocation) {
71+
constexpr size_t temp_memory_allocator_pool_size = 4;
72+
auto temp_memory_allocator_pool =
73+
std::make_unique<uint8_t[]>(temp_memory_allocator_pool_size);
74+
MemoryAllocator temp_allocator(
75+
temp_memory_allocator_pool_size, temp_memory_allocator_pool.get());
76+
KernelRuntimeContext context(nullptr, &temp_allocator);
77+
Result<void*> allocated_memory = context.allocate_temp(4);
78+
EXPECT_EQ(allocated_memory.ok(), true);
79+
}
80+
81+
TEST_F(KernelRuntimeContextTest, FailureMemoryAllocationInsufficientSpace) {
82+
constexpr size_t temp_memory_allocator_pool_size = 4;
83+
auto temp_memory_allocator_pool =
84+
std::make_unique<uint8_t[]>(temp_memory_allocator_pool_size);
85+
MemoryAllocator temp_allocator(
86+
temp_memory_allocator_pool_size, temp_memory_allocator_pool.get());
87+
KernelRuntimeContext context(nullptr, &temp_allocator);
88+
Result<void*> allocated_memory = context.allocate_temp(8);
89+
EXPECT_EQ(allocated_memory.error(), Error::MemoryAllocationFailed);
90+
}
91+
92+
TEST_F(KernelRuntimeContextTest, MemoryAllocatorAlignmentPassed) {
93+
constexpr size_t temp_memory_allocator_pool_size = 4;
94+
auto temp_memory_allocator_pool =
95+
std::make_unique<uint8_t[]>(temp_memory_allocator_pool_size);
96+
TestMemoryAllocator temp_allocator(
97+
temp_memory_allocator_pool_size, temp_memory_allocator_pool.get());
98+
KernelRuntimeContext context(nullptr, &temp_allocator);
99+
Result<void*> allocated_memory = context.allocate_temp(4, 2);
100+
EXPECT_EQ(allocated_memory.ok(), true);
101+
EXPECT_EQ(temp_allocator.last_seen_alignment, 2);
102+
}

0 commit comments

Comments
 (0)