Skip to content

Commit 58c8c92

Browse files
dbort authored and facebook-github-bot committed
Make HierarchicalAllocator use buffers instead of MemoryAllocators (#387)
Summary:
Pull Request resolved: #387

The only operations that HierarchicalAllocator performs on its MemoryAllocator entries are `base_address()` and `size()`, which means that they're not really allocators; they're just buffers. Move to using an array of simple "pointer and size" elements (represented by Spans).

This will ultimately let us remove `base_address()` and `size()` from `MemoryAllocator`, which are incompatible with dynamic subclasses like `MallocMemoryAllocator`.

To help demonstrate that the new version works, make `ManagedMemoryManager` use spans. This will cause a few dozen tests around the tree to start using spans as well.

ghstack-source-id: 201128784
exported-using-ghexport

Reviewed By: JacobSzwejbka

Differential Revision: D49344931

fbshipit-source-id: f2fec0655e0c8a88478d98a38506e436b2702fc0
1 parent 650333e commit 58c8c92

File tree

3 files changed

+128
-49
lines changed

3 files changed

+128
-49
lines changed

runtime/core/hierarchical_allocator.h

Lines changed: 62 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#include <executorch/runtime/core/memory_allocator.h>
1212
#include <executorch/runtime/core/result.h>
13+
#include <executorch/runtime/core/span.h>
1314
#include <executorch/runtime/platform/assert.h>
1415
#include <executorch/runtime/platform/compiler.h>
1516
#include <executorch/runtime/platform/log.h>
@@ -18,75 +19,91 @@
1819
namespace torch {
1920
namespace executor {
2021

21-
// A group of allocators that can be used to represent a device's memory
22-
// hierarchy.
23-
class HierarchicalAllocator {
22+
/**
23+
* A group of buffers that can be used to represent a device's memory hierarchy.
24+
*/
25+
class HierarchicalAllocator final {
2426
public:
25-
// Constructs a new hierarchycal allocator with the given array of allocators.
26-
// Memory IDs are assigned based on the index in the 'allocators' array. E.g.
27-
// the first allocator in the array will have a memory ID of 0.
28-
HierarchicalAllocator(uint32_t n_allocators, MemoryAllocator* allocators)
29-
: n_allocators_(n_allocators), allocators_(allocators) {}
27+
/**
28+
* Constructs a new hierarchical allocator with the given array of buffers.
29+
*
30+
* - Memory IDs are based on the index into `buffers`: `buffers[N]` will have
31+
* a memory ID of `N`.
32+
* - `buffers.size()` must be >= `MethodMeta::num_non_const_buffers()`.
33+
* - `buffers[N].size()` must be >= `MethodMeta::non_const_buffer_size(N)`.
34+
*/
35+
explicit HierarchicalAllocator(Span<Span<uint8_t>> buffers)
36+
: buffers_(buffers) {}
37+
38+
/**
39+
* DEPRECATED: Use spans instead.
40+
*/
41+
__ET_DEPRECATED HierarchicalAllocator(
42+
uint32_t n_allocators,
43+
MemoryAllocator* allocators)
44+
: buffers_(to_spans(n_allocators, allocators)) {}
3045

3146
/**
3247
* Returns the address at the byte offset `offset_bytes` from the given
33-
* allocator's base address, which should have at least `size_bytes` of memory
34-
* available inside the allocator's buffer.
35-
*
36-
* This is useful to point an object to this address when such information has
37-
* been predetermined. This method assumes that the given memory's allocator
38-
* has already reserved enough memory (i.e. there's no actual allocation call
39-
* to the underlying memory allocator).
48+
* buffer's base address, which points to at least `size_bytes` of memory.
4049
*
41-
* @param[in] memory_id The ID of the allocator in the hierarchy.
42-
* @param[in] offset_bytes The offset in bytes into the memory of the
43-
* specified allocator.
50+
* @param[in] memory_id The ID of the buffer in the hierarchy.
51+
* @param[in] offset_bytes The offset in bytes into the specified buffer.
4452
* @param[in] size_bytes The amount of memory that should be available at
4553
* the offset.
4654
*
4755
* @returns On success, the address of the requested byte offset into the
48-
* specified allocator. On failure, a non-Ok Error.
56+
* specified buffer. On failure, a non-Ok Error.
4957
*/
5058
__ET_NODISCARD Result<void*> get_offset_address(
5159
uint32_t memory_id,
5260
size_t offset_bytes,
5361
size_t size_bytes) {
54-
Result<MemoryAllocator*> allocator_result = get_allocator(memory_id);
55-
if (!allocator_result.ok()) {
56-
return allocator_result.error();
57-
}
58-
auto allocator = allocator_result.get();
5962
ET_CHECK_OR_RETURN_ERROR(
60-
offset_bytes + size_bytes <= allocator->size(),
63+
memory_id < buffers_.size(),
64+
InvalidArgument,
65+
"id %" PRIu32 " >= %zu",
66+
memory_id,
67+
buffers_.size());
68+
Span<uint8_t> buffer = buffers_[memory_id];
69+
ET_CHECK_OR_RETURN_ERROR(
70+
offset_bytes + size_bytes <= buffer.size(),
6171
MemoryAllocationFailed,
62-
"offset_bytes (%zu) + size_bytes (%zu) >= allocator size (%" PRIu32
63-
") for memory_id %" PRIu32,
72+
"offset_bytes (%zu) + size_bytes (%zu) >= allocator size (%zu) "
73+
"for memory_id %" PRIu32,
6474
offset_bytes,
6575
size_bytes,
66-
allocator->size(),
76+
buffer.size(),
6777
memory_id);
68-
return allocator->base_address() + offset_bytes;
78+
return buffer.data() + offset_bytes;
6979
}
7080

71-
virtual ~HierarchicalAllocator() {}
72-
7381
private:
74-
/// Returns the memory allocator for the given 'memory_id' in the hierarchy.
75-
Result<MemoryAllocator*> get_allocator(uint32_t memory_id) const {
76-
ET_CHECK_OR_RETURN_ERROR(
77-
memory_id < n_allocators_,
78-
InvalidArgument,
79-
"Memory id %" PRIu32 " >= n_allocators_ %" PRIu32,
80-
memory_id,
81-
n_allocators_);
82-
return &allocators_[memory_id];
82+
// TODO(T162089316): Remove the span array and to_spans once all users move to
83+
// spans. This array is necessary to hold the pointers and sizes that were
84+
// originally provided as MemoryAllocator instances.
85+
static constexpr size_t kSpanArraySize = 16;
86+
// NOTE: span_array_ must be declared before buffers_ so that it isn't
87+
// re-initialized to zeros after initializing buffers_.
88+
Span<uint8_t> span_array_[kSpanArraySize];
89+
Span<Span<uint8_t>> to_spans(
90+
uint32_t n_allocators,
91+
MemoryAllocator* allocators) {
92+
ET_CHECK_MSG(
93+
n_allocators <= kSpanArraySize,
94+
"n_allocators %" PRIu32 " > %zu",
95+
n_allocators,
96+
kSpanArraySize);
97+
for (uint32_t i = 0; i < n_allocators; ++i) {
98+
span_array_[i] =
99+
Span<uint8_t>(allocators[i].base_address(), allocators[i].size());
100+
}
101+
return {span_array_, n_allocators};
83102
}
84103

85-
// The HierarchicalAllocator holds n_allocators_ MemoryAllocators.
86-
uint32_t n_allocators_;
87-
88-
// Memory allocators as an array, each ID corresponds to their index.
89-
MemoryAllocator* allocators_;
104+
/// The underlying buffers.
105+
Span<Span<uint8_t>> buffers_;
90106
};
107+
91108
} // namespace executor
92109
} // namespace torch

runtime/core/test/hierarchical_allocator_test.cpp

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,11 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9+
#include <cstdint>
10+
911
#include <executorch/runtime/core/hierarchical_allocator.h>
1012
#include <executorch/runtime/core/memory_allocator.h>
13+
#include <executorch/runtime/core/span.h>
1114
#include <executorch/runtime/platform/runtime.h>
1215
#include <executorch/test/utils/alignment.h>
1316

@@ -18,6 +21,7 @@ using torch::executor::Error;
1821
using torch::executor::HierarchicalAllocator;
1922
using torch::executor::MemoryAllocator;
2023
using torch::executor::Result;
24+
using torch::executor::Span;
2125

2226
class HierarchicalAllocatorTest : public ::testing::Test {
2327
protected:
@@ -29,6 +33,63 @@ class HierarchicalAllocatorTest : public ::testing::Test {
2933
};
3034

3135
TEST_F(HierarchicalAllocatorTest, Smoke) {
36+
constexpr size_t n_buffers = 2;
37+
constexpr size_t size0 = 4;
38+
constexpr size_t size1 = 8;
39+
uint8_t mem0[size0];
40+
uint8_t mem1[size1];
41+
Span<uint8_t> buffers[n_buffers]{
42+
{mem0, size0},
43+
{mem1, size1},
44+
};
45+
46+
HierarchicalAllocator allocator({buffers, n_buffers});
47+
48+
// get_offset_address() success cases
49+
{
50+
// Total size is 4, so off=0 + size=2 fits.
51+
Result<void*> address = allocator.get_offset_address(
52+
/*memory_id=*/0, /*offset_bytes=*/0, /*size_bytes=*/2);
53+
ASSERT_EQ(address.error(), Error::Ok);
54+
ASSERT_NE(address.get(), nullptr);
55+
ASSERT_EQ(address.get(), mem0);
56+
}
57+
{
58+
// Total size is 8, so off=4 + size=4 fits exactly.
59+
Result<void*> address = allocator.get_offset_address(
60+
/*memory_id=*/1, /*offset_bytes=*/4, /*size_bytes=*/4);
61+
ASSERT_EQ(address.error(), Error::Ok);
62+
ASSERT_NE(address.get(), nullptr);
63+
ASSERT_EQ(address.get(), mem1 + 4);
64+
}
65+
66+
// get_offset_address() failure cases
67+
{
68+
// Total size is 4, so off=4 + size=5 is too large.
69+
Result<void*> address = allocator.get_offset_address(
70+
/*memory_id=*/0, /*offset_bytes=*/4, /*size_bytes=*/5);
71+
ASSERT_FALSE(address.ok());
72+
ASSERT_NE(address.error(), Error::Ok);
73+
}
74+
{
75+
// Total size is 4, so off=8 + size=0 is off the end.
76+
Result<void*> address = allocator.get_offset_address(
77+
/*memory_id=*/0, /*offset_bytes=*/8, /*size_bytes=*/0);
78+
ASSERT_FALSE(address.ok());
79+
ASSERT_NE(address.error(), Error::Ok);
80+
}
81+
{
82+
// ID too large; only two zero-indexed entries in the allocator.
83+
Result<void*> address = allocator.get_offset_address(
84+
/*memory_id=*/2, /*offset_bytes=*/0, /*size_bytes=*/2);
85+
ASSERT_FALSE(address.ok());
86+
ASSERT_NE(address.error(), Error::Ok);
87+
}
88+
}
89+
90+
// TODO(T162089316): Tests the deprecated API. Remove this when removing the
91+
// API.
92+
TEST_F(HierarchicalAllocatorTest, DEPRECATEDSmoke) {
3293
constexpr size_t n_allocators = 2;
3394
constexpr size_t size0 = 4;
3495
constexpr size_t size1 = 8;

runtime/executor/test/managed_memory_manager.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,12 @@ class ManagedMemoryManager {
2929
: const_allocator_(0, nullptr),
3030
non_const_pool_(new uint8_t[non_const_mem_bytes]),
3131
non_const_allocators_({
32-
MemoryAllocator(non_const_mem_bytes, non_const_pool_.get()),
32+
{non_const_pool_.get(), non_const_mem_bytes},
3333
}),
34-
non_const_allocator_(
34+
non_const_allocator_({
35+
non_const_allocators_.data(),
3536
non_const_allocators_.size(),
36-
non_const_allocators_.data()),
37+
}),
3738
runtime_pool_(new uint8_t[runtime_mem_bytes]),
3839
runtime_allocator_(runtime_mem_bytes, runtime_pool_.get()),
3940
temp_allocator_(0, nullptr),
@@ -51,7 +52,7 @@ class ManagedMemoryManager {
5152
MemoryAllocator const_allocator_;
5253

5354
std::unique_ptr<uint8_t[]> non_const_pool_;
54-
std::vector<MemoryAllocator> non_const_allocators_;
55+
std::vector<Span<uint8_t>> non_const_allocators_;
5556
torch::executor::HierarchicalAllocator non_const_allocator_;
5657

5758
std::unique_ptr<uint8_t[]> runtime_pool_;

0 commit comments

Comments
 (0)