Skip to content

Commit 75898bf

Browse files
Enhance load_method to support optional planned memory allocator (#8032)
* Enhance load_method to support an optional planned memory allocator: updated the load_method signature to accept an optional runtime::HierarchicalAllocator parameter. * Cleaned up methods and deprecated the old interfaces. * Fixed linter errors. Co-authored-by: Jacob Szwejbka <[email protected]>
1 parent d71f54a commit 75898bf

File tree

2 files changed

+39
-20
lines changed

2 files changed

+39
-20
lines changed

extension/module/module.cpp

Lines changed: 21 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -178,34 +178,36 @@ runtime::Result<std::unordered_set<std::string>> Module::method_names() {
178178

179179
runtime::Error Module::load_method(
180180
const std::string& method_name,
181+
runtime::HierarchicalAllocator* planned_memory,
181182
torch::executor::EventTracer* event_tracer) {
182183
if (!is_method_loaded(method_name)) {
183184
ET_CHECK_OK_OR_RETURN_ERROR(load());
184185

185186
MethodHolder method_holder;
186187

187-
const auto method_metadata =
188-
ET_UNWRAP(program_->method_meta(method_name.c_str()));
189-
const auto planned_buffersCount =
190-
method_metadata.num_memory_planned_buffers();
191-
method_holder.planned_buffers.reserve(planned_buffersCount);
192-
method_holder.planned_spans.reserve(planned_buffersCount);
188+
if (!planned_memory) {
189+
const auto method_metadata =
190+
ET_UNWRAP(program_->method_meta(method_name.c_str()));
191+
const auto planned_buffers_count =
192+
method_metadata.num_memory_planned_buffers();
193+
method_holder.planned_buffers.reserve(planned_buffers_count);
194+
method_holder.planned_spans.reserve(planned_buffers_count);
193195

194-
for (auto index = 0; index < planned_buffersCount; ++index) {
195-
const auto buffer_size =
196-
method_metadata.memory_planned_buffer_size(index).get();
197-
method_holder.planned_buffers.emplace_back(buffer_size);
198-
method_holder.planned_spans.emplace_back(
199-
method_holder.planned_buffers.back().data(), buffer_size);
196+
for (auto index = 0; index < planned_buffers_count; ++index) {
197+
const auto buffer_size =
198+
method_metadata.memory_planned_buffer_size(index).get();
199+
method_holder.planned_buffers.emplace_back(buffer_size);
200+
method_holder.planned_spans.emplace_back(
201+
method_holder.planned_buffers.back().data(), buffer_size);
202+
}
203+
method_holder.planned_memory =
204+
std::make_unique<runtime::HierarchicalAllocator>(runtime::Span(
205+
method_holder.planned_spans.data(),
206+
method_holder.planned_spans.size()));
207+
planned_memory = method_holder.planned_memory.get();
200208
}
201-
method_holder.planned_memory =
202-
std::make_unique<runtime::HierarchicalAllocator>(runtime::Span(
203-
method_holder.planned_spans.data(),
204-
method_holder.planned_spans.size()));
205209
method_holder.memory_manager = std::make_unique<runtime::MemoryManager>(
206-
memory_allocator_.get(),
207-
method_holder.planned_memory.get(),
208-
temp_allocator_.get());
210+
memory_allocator_.get(), planned_memory, temp_allocator_.get());
209211
method_holder.method = ET_UNWRAP_UNIQUE(program_->load_method(
210212
method_name.c_str(),
211213
method_holder.memory_manager.get(),

extension/module/module.h

Lines changed: 18 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -152,6 +152,8 @@ class Module {
152152
* needed. The loaded method is cached to reuse the next time it's executed.
153153
*
154154
* @param[in] method_name The name of the method to load.
155+
* @param[in] planned_memory The memory-planned buffers to use for mutable
156+
* tensor data when executing a method.
155157
* @param[in] event_tracer Per-method event tracer to profile/trace methods
156158
* individually. When not given, the event tracer passed to the Module
157159
* constructor is used. Otherwise, this per-method event tracer takes
@@ -162,20 +164,35 @@ class Module {
162164
ET_NODISCARD
163165
runtime::Error load_method(
164166
const std::string& method_name,
167+
runtime::HierarchicalAllocator* planned_memory = nullptr,
165168
torch::executor::EventTracer* event_tracer = nullptr);
166169

170+
ET_DEPRECATED ET_NODISCARD runtime::Error inline load_method(
171+
const std::string& method_name,
172+
torch::executor::EventTracer* event_tracer) {
173+
return load_method(method_name, nullptr, event_tracer);
174+
}
175+
167176
/**
168177
* Load the 'forward' method from the program and set up memory management if
169178
* needed. The loaded method is cached to reuse the next time it's executed.
170179
*
180+
* @param[in] planned_memory The memory-planned buffers to use for mutable
181+
* tensor data when executing the 'forward' method.
171182
* @param[in] event_tracer An event tracer used for tracking and logging
172183
* events.
173184
*
174185
* @returns An Error to indicate success or failure.
175186
*/
176187
/**
 * Loads the 'forward' method, setting up memory management as needed.
 *
 * @param[in] planned_memory Optional memory-planned buffers for mutable
 * tensor data; when null, the Module plans memory internally.
 * @param[in] event_tracer Optional event tracer for profiling/tracing.
 *
 * @returns An Error to indicate success or failure.
 */
ET_NODISCARD inline runtime::Error load_forward(
    runtime::HierarchicalAllocator* planned_memory = nullptr,
    torch::executor::EventTracer* event_tracer = nullptr) {
  // 'forward' is the conventional entry-point method name; loading it is
  // simply load_method() specialized to that name.
  static constexpr const char* kForwardMethodName = "forward";
  return load_method(kForwardMethodName, planned_memory, event_tracer);
}
192+
193+
/**
 * DEPRECATED: Use load_forward(planned_memory, event_tracer) instead.
 *
 * @param[in] event_tracer Event tracer for profiling/tracing the method.
 *
 * @returns An Error to indicate success or failure.
 */
ET_DEPRECATED ET_NODISCARD inline runtime::Error load_forward(
    torch::executor::EventTracer* event_tracer) {
  // Forward to the primary overload with no caller-provided planned memory.
  return load_forward(nullptr, event_tracer);
}
180197

181198
/**

0 commit comments

Comments
 (0)