@@ -178,34 +178,36 @@ runtime::Result<std::unordered_set<std::string>> Module::method_names() {
178
178
179
179
runtime::Error Module::load_method (
180
180
const std::string& method_name,
181
+ runtime::HierarchicalAllocator* planned_memory,
181
182
torch::executor::EventTracer* event_tracer) {
182
183
if (!is_method_loaded (method_name)) {
183
184
ET_CHECK_OK_OR_RETURN_ERROR (load ());
184
185
185
186
MethodHolder method_holder;
186
187
187
- const auto method_metadata =
188
- ET_UNWRAP (program_->method_meta (method_name.c_str ()));
189
- const auto planned_buffersCount =
190
- method_metadata.num_memory_planned_buffers ();
191
- method_holder.planned_buffers .reserve (planned_buffersCount);
192
- method_holder.planned_spans .reserve (planned_buffersCount);
188
+ if (!planned_memory) {
189
+ const auto method_metadata =
190
+ ET_UNWRAP (program_->method_meta (method_name.c_str ()));
191
+ const auto planned_buffers_count =
192
+ method_metadata.num_memory_planned_buffers ();
193
+ method_holder.planned_buffers .reserve (planned_buffers_count);
194
+ method_holder.planned_spans .reserve (planned_buffers_count);
193
195
194
- for (auto index = 0 ; index < planned_buffersCount; ++index) {
195
- const auto buffer_size =
196
- method_metadata.memory_planned_buffer_size (index).get ();
197
- method_holder.planned_buffers .emplace_back (buffer_size);
198
- method_holder.planned_spans .emplace_back (
199
- method_holder.planned_buffers .back ().data (), buffer_size);
196
+ for (auto index = 0 ; index < planned_buffers_count; ++index) {
197
+ const auto buffer_size =
198
+ method_metadata.memory_planned_buffer_size (index).get ();
199
+ method_holder.planned_buffers .emplace_back (buffer_size);
200
+ method_holder.planned_spans .emplace_back (
201
+ method_holder.planned_buffers .back ().data (), buffer_size);
202
+ }
203
+ method_holder.planned_memory =
204
+ std::make_unique<runtime::HierarchicalAllocator>(runtime::Span (
205
+ method_holder.planned_spans .data (),
206
+ method_holder.planned_spans .size ()));
207
+ planned_memory = method_holder.planned_memory .get ();
200
208
}
201
- method_holder.planned_memory =
202
- std::make_unique<runtime::HierarchicalAllocator>(runtime::Span (
203
- method_holder.planned_spans .data (),
204
- method_holder.planned_spans .size ()));
205
209
method_holder.memory_manager = std::make_unique<runtime::MemoryManager>(
206
- memory_allocator_.get (),
207
- method_holder.planned_memory .get (),
208
- temp_allocator_.get ());
210
+ memory_allocator_.get (), planned_memory, temp_allocator_.get ());
209
211
method_holder.method = ET_UNWRAP_UNIQUE (program_->load_method (
210
212
method_name.c_str (),
211
213
method_holder.memory_manager .get (),
0 commit comments