Skip to content

Commit 50e94c7

Browse files
shoumikhinfacebook-github-bot
authored andcommitted
Let Module use FIleDataLoader when requested. (#4174)
Summary: Pull Request resolved: #4174 Hide more data loader options behind the facade. Reviewed By: kirklandsign Differential Revision: D59498348
1 parent 561c035 commit 50e94c7

File tree

9 files changed

+54
-47
lines changed

9 files changed

+54
-47
lines changed

docs/source/extension-module.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ Use [ExecuTorch Dump](sdk-etdump.md) to trace model execution. Create an instanc
136136

137137
using namespace ::torch::executor;
138138

139-
Module module("/path/to/model.pte", Module::MlockConfig::UseMlock, std::make_unique<ETDumpGen>());
139+
Module module("/path/to/model.pte", Module::LoadMode::MmapUseMlock, std::make_unique<ETDumpGen>());
140140

141141
// Execute a method, e.g. module.forward(...); or module.execute("my_method", ...);
142142

docs/source/llm/getting-started.md

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,8 @@ penalties for repeated tokens, and biases to prioritize or de-prioritize specifi
313313
```cpp
314314
// main.cpp
315315

316+
using namespace torch::executor;
317+
316318
int main() {
317319
// Set up the prompt. This provides the seed text for the model to elaborate.
318320
std::cout << "Enter model prompt: ";
@@ -327,7 +329,7 @@ int main() {
327329
BasicSampler sampler = BasicSampler();
328330

329331
// Load the exported nanoGPT program, which was generated via the previous steps.
330-
Module model("nanogpt.pte", torch::executor::Module::MlockConfig::UseMlockIgnoreErrors);
332+
Module model("nanogpt.pte", Module::LoadMode::MmapUseMlockIgnoreErrors);
331333

332334
const auto max_input_tokens = 1024;
333335
const auto max_output_tokens = 30;
@@ -787,15 +789,14 @@ Include the ETDump header in your code.
787789

788790
Create an Instance of the ETDumpGen class and pass it to the Module constructor.
789791
```cpp
790-
std::unique_ptr<torch::executor::ETDumpGen> etdump_gen_ = std::make_unique<torch::executor::ETDumpGen>();
791-
Module model("nanogpt.pte", torch::executor::Module::MlockConfig::UseMlockIgnoreErrors, std::move(etdump_gen_));
792+
std::unique_ptr<ETDumpGen> etdump_gen_ = std::make_unique<ETDumpGen>();
793+
Module model("nanogpt.pte", Module::LoadMode::MmapUseMlockIgnoreErrors, std::move(etdump_gen_));
792794
```
793795
794796
After calling `generate()`, save the ETDump to a file. You can capture multiple
795797
model runs in a single trace, if desired.
796798
```cpp
797-
torch::executor::ETDumpGen* etdump_gen =
798-
static_cast<torch::executor::ETDumpGen*>(model.event_tracer());
799+
ETDumpGen* etdump_gen = static_cast<ETDumpGen*>(model.event_tracer());
799800
800801
ET_LOG(Info, "ETDump size: %zu blocks", etdump_gen->get_num_blocks());
801802
etdump_result result = etdump_gen->get_etdump_data();

examples/llm_manual/main.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,9 +110,7 @@ int main() {
110110

111111
// Load the exported nanoGPT program, which was generated via the previous
112112
// steps.
113-
Module model(
114-
"nanogpt.pte",
115-
torch::executor::Module::MlockConfig::UseMlockIgnoreErrors);
113+
Module model("nanogpt.pte", Module::LoadMode::MmapUseMlockIgnoreErrors);
116114

117115
const auto max_input_tokens = 1024;
118116
const auto max_output_tokens = 30;

examples/models/phi-3-mini/main.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ int main() {
8383

8484
SentencePieceTokenizer tokenizer("tokenizer.model");
8585

86-
Module model("phi-3-mini.pte", Module::MlockConfig::UseMlockIgnoreErrors);
86+
Module model("phi-3-mini.pte", Module::LoadMode::MmapUseMlockIgnoreErrors);
8787

8888
const auto max_output_tokens = 128;
8989
generate(model, prompt, tokenizer, max_output_tokens);

examples/qualcomm/llama2/runner/runner.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ Runner::Runner(
3838
const float temperature)
3939
: module_(std::make_unique<Module>(
4040
model_path,
41-
Module::MlockConfig::UseMlockIgnoreErrors)),
41+
Module::LoadMode::MmapUseMlockIgnoreErrors)),
4242
tokenizer_path_(tokenizer_path),
4343
model_path_(model_path),
4444
temperature_(temperature) {
@@ -649,7 +649,7 @@ Error Runner::mem_alloc(size_t alignment, size_t seq_len) {
649649
// Reset and re-init again to trigger registered function
650650
module_.reset();
651651
module_ = std::make_unique<Module>(
652-
model_path_, Module::MlockConfig::UseMlockIgnoreErrors),
652+
model_path_, Module::LoadMode::MmapUseMlockIgnoreErrors),
653653
ET_CHECK_MSG(load() == Error::Ok, "Runner failed to load method");
654654

655655
return Error::Ok;

extension/android/jni/jni_layer.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ class JEValue : public facebook::jni::JavaClass<JEValue> {
233233
class ExecuTorchJni : public facebook::jni::HybridClass<ExecuTorchJni> {
234234
private:
235235
friend HybridBase;
236-
std::unique_ptr<torch::executor::Module> module_;
236+
std::unique_ptr<Module> module_;
237237

238238
public:
239239
constexpr static auto kJavaDescriptor = "Lorg/pytorch/executorch/NativePeer;";
@@ -252,9 +252,8 @@ class ExecuTorchJni : public facebook::jni::HybridClass<ExecuTorchJni> {
252252
facebook::jni::alias_ref<
253253
facebook::jni::JMap<facebook::jni::JString, facebook::jni::JString>>
254254
extraFiles) {
255-
module_ = std::make_unique<torch::executor::Module>(
256-
modelPath->toStdString(),
257-
torch::executor::Module::MlockConfig::NoMlock);
255+
module_ = std::make_unique<Module>(
256+
modelPath->toStdString(), Module::LoadMode::Mmap);
258257
}
259258

260259
facebook::jni::local_ref<facebook::jni::JArrayClass<JEValue>> forward(

extension/module/module.cpp

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include <executorch/extension/module/module.h>
1010

11+
#include <executorch/extension/data_loader/file_data_loader.h>
1112
#include <executorch/extension/data_loader/mmap_data_loader.h>
1213
#include <executorch/extension/memory_allocator/malloc_memory_allocator.h>
1314
#include <executorch/runtime/platform/runtime.h>
@@ -36,10 +37,10 @@ namespace torch::executor {
3637

3738
Module::Module(
3839
const std::string& file_path,
39-
const Module::MlockConfig mlock_config,
40+
const Module::LoadMode load_mode,
4041
std::unique_ptr<EventTracer> event_tracer)
4142
: file_path_(file_path),
42-
mlock_config_(mlock_config),
43+
load_mode_(load_mode),
4344
memory_allocator_(std::make_unique<util::MallocMemoryAllocator>()),
4445
temp_allocator_(std::make_unique<util::MallocMemoryAllocator>()),
4546
event_tracer_(std::move(event_tracer)) {
@@ -49,36 +50,41 @@ Module::Module(
4950
Module::Module(
5051
std::unique_ptr<DataLoader> data_loader,
5152
std::unique_ptr<MemoryAllocator> memory_allocator,
52-
std::unique_ptr<MemoryAllocator> tmp_memory_allocator,
53+
std::unique_ptr<MemoryAllocator> temp_allocator,
5354
std::unique_ptr<EventTracer> event_tracer)
5455
: data_loader_(std::move(data_loader)),
5556
memory_allocator_(
5657
memory_allocator ? std::move(memory_allocator)
5758
: std::make_unique<util::MallocMemoryAllocator>()),
5859
temp_allocator_(
59-
60-
tmp_memory_allocator
61-
? std::move(tmp_memory_allocator)
62-
: std::make_unique<util::MallocMemoryAllocator>()),
60+
temp_allocator ? std::move(temp_allocator)
61+
: std::make_unique<util::MallocMemoryAllocator>()),
6362
event_tracer_(std::move(event_tracer)) {
6463
runtime_init();
6564
}
6665

6766
Error Module::load(const Program::Verification verification) {
6867
if (!is_loaded()) {
6968
if (!data_loader_) {
70-
data_loader_ = ET_UNWRAP_UNIQUE(
71-
util::MmapDataLoader::from(file_path_.c_str(), [this] {
72-
switch (mlock_config_) {
73-
case MlockConfig::NoMlock:
74-
return util::MmapDataLoader::MlockConfig::NoMlock;
75-
case MlockConfig::UseMlock:
76-
return util::MmapDataLoader::MlockConfig::UseMlock;
77-
case MlockConfig::UseMlockIgnoreErrors:
78-
return util::MmapDataLoader::MlockConfig::UseMlockIgnoreErrors;
79-
}
80-
ET_ASSERT_UNREACHABLE();
81-
}()));
69+
switch (load_mode_) {
70+
case LoadMode::File:
71+
data_loader_ =
72+
ET_UNWRAP_UNIQUE(util::FileDataLoader::from(file_path_.c_str()));
73+
break;
74+
case LoadMode::Mmap:
75+
data_loader_ =
76+
ET_UNWRAP_UNIQUE(util::MmapDataLoader::from(file_path_.c_str()));
77+
break;
78+
case LoadMode::MmapUseMlock:
79+
data_loader_ = ET_UNWRAP_UNIQUE(util::MmapDataLoader::from(
80+
file_path_.c_str(), util::MmapDataLoader::MlockConfig::NoMlock));
81+
break;
82+
case LoadMode::MmapUseMlockIgnoreErrors:
83+
data_loader_ = ET_UNWRAP_UNIQUE(util::MmapDataLoader::from(
84+
file_path_.c_str(),
85+
util::MmapDataLoader::MlockConfig::UseMlockIgnoreErrors));
86+
break;
87+
}
8288
};
8389
program_ =
8490
ET_UNWRAP_UNIQUE(Program::load(data_loader_.get(), verification));

extension/module/module.h

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -24,42 +24,44 @@ namespace torch::executor {
2424
class Module final {
2525
public:
2626
/**
27-
* Enum to define memory locking behavior.
27+
* Enum to define loading behavior.
2828
*/
29-
enum class MlockConfig {
30-
/// Do not use memory locking.
31-
NoMlock,
29+
enum class LoadMode {
30+
/// Load the whole file as a buffer.
31+
File,
32+
/// Use mmap to load pages into memory.
33+
Mmap,
3234
/// Use memory locking and handle errors.
33-
UseMlock,
35+
MmapUseMlock,
3436
/// Use memory locking and ignore errors.
35-
UseMlockIgnoreErrors,
37+
MmapUseMlockIgnoreErrors,
3638
};
3739

3840
/**
3941
* Constructs an instance by loading a program from a file with specified
4042
* memory locking behavior.
4143
*
4244
* @param[in] file_path The path to the ExecuTorch program file to load.
43-
* @param[in] mlock_config The memory locking configuration to use.
45+
* @param[in] load_mode The loading mode to use.
4446
*/
4547
explicit Module(
4648
const std::string& file_path,
47-
const MlockConfig mlock_config = MlockConfig::UseMlock,
49+
const LoadMode load_mode = LoadMode::MmapUseMlock,
4850
std::unique_ptr<EventTracer> event_tracer = nullptr);
4951

5052
/**
5153
* Constructs an instance with the provided data loader and memory allocator.
5254
*
5355
* @param[in] data_loader A DataLoader used for loading program data.
5456
* @param[in] memory_allocator A MemoryAllocator used for memory management.
55-
* @param[in] tmp_memory_allocator A MemoryAllocator used for allocating
56-
* memory during execution time.
57+
* @param[in] temp_allocator A MemoryAllocator to use when allocating
58+
* temporary data during kernel or delegate execution.
5759
* @param[in] event_tracer A EventTracer used for tracking and logging events.
5860
*/
5961
explicit Module(
6062
std::unique_ptr<DataLoader> data_loader,
6163
std::unique_ptr<MemoryAllocator> memory_allocator = nullptr,
62-
std::unique_ptr<MemoryAllocator> tmp_memory_allocator = nullptr,
64+
std::unique_ptr<MemoryAllocator> temp_allocator = nullptr,
6365
std::unique_ptr<EventTracer> event_tracer = nullptr);
6466
Module(const Module&) = delete;
6567
Module& operator=(const Module&) = delete;
@@ -215,7 +217,7 @@ class Module final {
215217

216218
private:
217219
std::string file_path_;
218-
MlockConfig mlock_config_{MlockConfig::NoMlock};
220+
LoadMode load_mode_{LoadMode::MmapUseMlock};
219221
std::unique_ptr<DataLoader> data_loader_;
220222
std::unique_ptr<MemoryAllocator> memory_allocator_;
221223
std::unique_ptr<MemoryAllocator> temp_allocator_;

extension/module/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def define_common_targets():
2323
],
2424
deps = [
2525
"//executorch/extension/memory_allocator:malloc_memory_allocator",
26+
"//executorch/extension/data_loader:file_data_loader",
2627
"//executorch/extension/data_loader:mmap_data_loader",
2728
],
2829
exported_deps = [

0 commit comments

Comments
 (0)