Skip to content

Commit 504f298

Browse files
jackzhxngfacebook-github-bot
authored andcommitted
Introduce new DataLoader::load() with segment info
Summary: Adds a new DataLoader::load() function with an additonal parameter, SegmentInfo. Differential Revision: D59399538
1 parent f32d707 commit 504f298

File tree

3 files changed

+73
-3
lines changed

3 files changed

+73
-3
lines changed

runtime/core/data_loader.h

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#pragma once
1010

1111
#include <cstddef>
12+
#include <optional>
1213

1314
#include <executorch/runtime/core/freeable_buffer.h>
1415
#include <executorch/runtime/core/result.h>
@@ -17,6 +18,27 @@
1718
namespace torch {
1819
namespace executor {
1920

21+
enum class SegmentType { Program, Constant, Backend };
22+
23+
struct SegmentInfo {
24+
size_t segment_index = 0;
25+
SegmentType segment_type = SegmentType::Program;
26+
const char* descriptor =
27+
nullptr; // This can be any descriptor to pass to the data loader. At the
28+
// moment for backends this would be a backend id and for null
29+
// otherwise.
30+
31+
SegmentInfo() = default;
32+
33+
SegmentInfo(
34+
size_t segment_index,
35+
SegmentType segment_type,
36+
const char* descriptor)
37+
: segment_index(segment_index),
38+
segment_type(segment_type),
39+
descriptor(descriptor) {}
40+
};
41+
2042
/**
2143
* Loads from a data source.
2244
*
@@ -27,16 +49,33 @@ class DataLoader {
2749
virtual ~DataLoader() = default;
2850

2951
/**
52+
* DEPRECATED: Use `load()` going forward for access to segment info during
53+
* the load.
54+
*
3055
* Loads `size` bytes at byte offset `offset` from the underlying data source
3156
* into a `FreeableBuffer`, which owns the memory.
3257
*
3358
* NOTE: This must be thread-safe. If this call modifies common state, the
3459
* implementation must do its own locking.
3560
*/
36-
__ET_NODISCARD virtual Result<FreeableBuffer> Load(
61+
__ET_DEPRECATED __ET_NODISCARD virtual Result<FreeableBuffer> Load(
3762
size_t offset,
3863
size_t size) = 0;
3964

65+
/**
66+
* Loads `size` bytes at byte offset `offset` from the underlying data source
67+
* into a `FreeableBuffer`, which owns the memory.
68+
*
69+
* NOTE: This must be thread-safe. If this call modifies common state, the
70+
* implementation must do its own locking.
71+
*/
72+
__ET_NODISCARD virtual inline Result<FreeableBuffer> load(
73+
size_t _offset,
74+
size_t _size,
75+
const std::optional<SegmentInfo>& _segment_info) {
76+
return Error::NotImplemented;
77+
}
78+
4079
/**
4180
* Returns the length of the underlying data source, typically the file size.
4281
*/

runtime/executor/program.cpp

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010

1111
#include <cstddef>
1212
#include <cstdint>
13+
#include <iostream>
14+
#include <optional>
1315

1416
#include <executorch/runtime/core/event_tracer_hooks.h>
1517
#include <executorch/runtime/executor/memory_manager.h>
@@ -423,8 +425,19 @@ Result<FreeableBuffer> Program::LoadSegment(size_t index) const {
423425
// Could fail if offset and size are out of bound for the data, or if this
424426
// is reading from a file and fails, or for many other reasons depending on
425427
// the implementation of the loader.
426-
return loader_->Load(
427-
segment_base_offset_ + segment->offset(), segment->size());
428+
// TODO(jackzhxng): deprecate Load() and use load() instead.
429+
Result<FreeableBuffer> segment_data = loader_->load(
430+
segment_base_offset_ + segment->offset(),
431+
segment->size(),
432+
std::nullopt /* segment_info */);
433+
if (segment_data.error() == Error::NotImplemented) {
434+
std::cerr << "Using the old Load()" << std::endl; // TODO: remove debug
435+
return loader_->Load(
436+
segment_base_offset_ + segment->offset(), segment->size());
437+
} else {
438+
std::cerr << "Using the new load()" << std::endl; // TODO: remove debug
439+
return segment_data;
440+
}
428441
}
429442

430443
} // namespace executor

runtime/executor/test/backend_integration_test.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,24 @@ class DataLoaderSpy : public DataLoader {
188188
context->buffer.data(), context->buffer.size(), FreeBuffer, context);
189189
}
190190

191+
// TODO: figure out how to test that the segment_info is being used and
192+
// passed in properly.
193+
Result<FreeableBuffer> load(
194+
size_t offset,
195+
size_t size,
196+
const std::optional<torch::executor::SegmentInfo>& segment_info)
197+
override {
198+
Result<FreeableBuffer> buf = delegate_->Load(offset, size);
199+
if (!buf.ok()) {
200+
return buf.error();
201+
}
202+
operations_.push_back({Operation::Load, offset, /*data=*/nullptr, size});
203+
auto* context = new SpyContext(&operations_, std::move(buf.get()));
204+
// Use context->buffer since buf has been moved.
205+
return FreeableBuffer(
206+
context->buffer.data(), context->buffer.size(), FreeBuffer, context);
207+
}
208+
191209
Result<size_t> size() const override {
192210
return delegate_->size();
193211
}

0 commit comments

Comments
 (0)