
Commit 8ef6c79

dbort authored and facebook-github-bot committed
Move examples/qualcomm out from under the torch namespace (#5400)
Summary:
The code under examples/... is a proxy for user code, and users should never declare code under the `torch::` or `executorch::` namespaces. Move this code under the `example::` namespace to make it clearer that users should use their own namespaces when writing code like this.

Pull Request resolved: #5400

Test Plan:
- Built using the instructions at https://github.com/pytorch/executorch/blob/main/examples/qualcomm/README.md
- test-llama-runner-qnn-linux CI job succeeded

Reviewed By: shoumikhin

Differential Revision: D62969111

Pulled By: dbort

fbshipit-source-id: 9ec27528dd85f60d8c538d54ce6ddf621e63cf52
1 parent b89c52c commit 8ef6c79
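
To illustrate the convention this commit enforces, here is a minimal sketch of user/example code living in its own namespace. The `my_app` namespace and `run()` function are hypothetical and exist only to show the pattern; the `executorch::runtime` names match the ones the diffs below switch to, and the header paths are the usual ExecuTorch include locations.

// Minimal sketch: user code declares its own namespace and pulls in
// ExecuTorch types with explicit using-declarations instead of
// `using namespace torch::executor`.
#include <executorch/runtime/core/error.h>
#include <executorch/runtime/platform/runtime.h>

using executorch::runtime::Error;

namespace my_app {  // hypothetical user-owned namespace, never torch:: or executorch::

Error run() {
  // Fully-qualified call, as in qnn_executor_runner.cpp below.
  executorch::runtime::runtime_init();
  return Error::Ok;
}

}  // namespace my_app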

File tree

13 files changed: +228 -181 lines changed


examples/qualcomm/executor_runner/qnn_executor_runner.cpp

Lines changed: 24 additions & 9 deletions
@@ -71,9 +71,24 @@ DEFINE_int32(
     20000000, // 20MB
     "Size of the debug buffer in bytes to allocate for intermediate outputs and program outputs logging.");
 
-using namespace torch::executor;
-using torch::executor::MemoryAllocator;
-using torch::executor::util::FileDataLoader;
+using executorch::aten::Tensor;
+using executorch::aten::TensorImpl;
+using executorch::etdump::ETDumpGen;
+using executorch::etdump::ETDumpResult;
+using executorch::extension::FileDataLoader;
+using executorch::extension::prepare_input_tensors;
+using executorch::runtime::Error;
+using executorch::runtime::EValue;
+using executorch::runtime::EventTracerDebugLogLevel;
+using executorch::runtime::HierarchicalAllocator;
+using executorch::runtime::MemoryAllocator;
+using executorch::runtime::MemoryManager;
+using executorch::runtime::Method;
+using executorch::runtime::MethodMeta;
+using executorch::runtime::Program;
+using executorch::runtime::Result;
+using executorch::runtime::Span;
+using executorch::runtime::TensorInfo;
 
 class CustomMemory {
  public:
@@ -112,7 +127,7 @@ class CustomMemory {
 };
 
 int main(int argc, char** argv) {
-  runtime_init();
+  executorch::runtime::runtime_init();
 
   gflags::ParseCommandLineFlags(&argc, &argv, true);
   if (argc != 1) {
@@ -211,7 +226,7 @@ int main(int argc, char** argv) {
   // the method can mutate the memory-planned buffers, so the method should only
   // be used by a single thread at at time, but it can be reused.
   //
-  torch::executor::ETDumpGen etdump_gen = torch::executor::ETDumpGen();
+  ETDumpGen etdump_gen;
   Result<Method> method =
       program->load_method(method_name, &memory_manager, &etdump_gen);
   ET_CHECK_MSG(
@@ -261,7 +276,7 @@ int main(int argc, char** argv) {
   }
   for (int output_index = 0; output_index < method->outputs_size();
        ++output_index) {
-    const exec_aten::Tensor& t = method->get_output(output_index).toTensor();
+    const Tensor& t = method->get_output(output_index).toTensor();
     out_custom_mem.push_back(
         std::make_unique<CustomMemory>(FLAGS_shared_buffer));
     std::unique_ptr<CustomMemory>& custom_mem_ptr = out_custom_mem.back();
@@ -415,7 +430,7 @@ int main(int argc, char** argv) {
         elapsed_time / inference_index);
   } else {
     // if no input is provided, fill the inputs with default values
-    auto inputs = util::prepare_input_tensors(*method);
+    auto inputs = prepare_input_tensors(*method);
     ET_CHECK_MSG(
         inputs.ok(),
         "Could not prepare inputs: 0x%" PRIx32,
@@ -434,7 +449,7 @@ int main(int argc, char** argv) {
 
   // Dump the etdump data containing profiling/debugging data to the specified
   // file.
-  etdump_result result = etdump_gen.get_etdump_data();
+  ETDumpResult result = etdump_gen.get_etdump_data();
   if (result.buf != nullptr && result.size > 0) {
     ET_LOG(
         Info,
@@ -452,7 +467,7 @@ int main(int argc, char** argv) {
         Info,
         "Write debug output binary to %s, Size = %zu",
         FLAGS_debug_output_path.c_str(),
-        FLAGS_debug_buffer_size);
+        (size_t)FLAGS_debug_buffer_size);
     FILE* f = fopen(FLAGS_debug_output_path.c_str(), "w+");
     fwrite((uint8_t*)debug_buffer, 1, FLAGS_debug_buffer_size, f);
     fclose(f);

examples/qualcomm/oss_scripts/llama2/qnn_llama_runner.cpp

Lines changed: 6 additions & 5 deletions
@@ -23,8 +23,6 @@
 #include <fstream>
 #include <vector>
 
-using torch::executor::MemoryAllocator;
-
 DEFINE_string(
     model_path,
     "qnn_llama2.pte",
@@ -49,9 +47,12 @@ DEFINE_int32(
     128,
     "Total number of tokens to generate (prompt + output). Defaults to max_seq_len. If the number of input tokens + seq_len > max_seq_len, the output will be truncated to max_seq_len tokens.");
 
-int main(int argc, char** argv) {
-  using namespace torch::executor;
+using executorch::runtime::Error;
+using executorch::runtime::MemoryAllocator;
+using executorch::runtime::MethodMeta;
+using executorch::runtime::Result;
 
+int main(int argc, char** argv) {
   gflags::ParseCommandLineFlags(&argc, &argv, true);
 
   const char* tokenizer_path = FLAGS_tokenizer_path.c_str();
@@ -60,7 +61,7 @@ int main(int argc, char** argv) {
   int32_t seq_len = FLAGS_seq_len;
 
   // create llama runner
-  Runner runner(FLAGS_model_path, tokenizer_path, temperature);
+  example::Runner runner(FLAGS_model_path, tokenizer_path, temperature);
   ET_CHECK_MSG(runner.load() == Error::Ok, "Runner failed to load method");
 
   // MethodMeta describes the memory requirements of the method.

examples/qualcomm/oss_scripts/llama2/runner/runner.cpp

Lines changed: 43 additions & 29 deletions
@@ -22,11 +22,27 @@
 #include <memory>
 #include <sstream>
 
-namespace torch {
-namespace executor {
+using executorch::aten::ScalarType;
+using executorch::aten::SizesType;
+using executorch::aten::Tensor;
+using executorch::extension::from_blob;
+using executorch::extension::Module;
+using executorch::extension::TensorPtr;
+using executorch::extension::llm::BPETokenizer;
+using executorch::extension::llm::Sampler;
+using executorch::extension::llm::time_in_ms;
+using executorch::runtime::Error;
+using executorch::runtime::EValue;
+using executorch::runtime::MethodMeta;
+using executorch::runtime::Result;
+using executorch::runtime::TensorInfo;
+
+// TODO: Remove this usage of an internal-only function.
+using executorch::runtime::internal::set_tensor_data;
+
+namespace example {
 
 namespace {
-using namespace executorch::extension;
 static constexpr auto kTopp = 0.9f;
 void printReport(const Runner::Stats& stats);
 std::string statsToJsonString(const Runner::Stats& stats);
@@ -57,7 +73,7 @@ Error Runner::load() {
   if (is_loaded()) {
     return Error::Ok;
   }
-  stats_.model_load_start_ms = util::time_in_ms();
+  stats_.model_load_start_ms = time_in_ms();
   ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method("forward"));
 
   // Read out metadata from the model
@@ -97,7 +113,7 @@ Error Runner::load() {
       temperature_,
       kTopp,
       static_cast<unsigned long long>(std::time(nullptr)));
-  stats_.model_load_end_ms = util::time_in_ms();
+  stats_.model_load_end_ms = time_in_ms();
 
   return Error::Ok;
 }
@@ -125,7 +141,7 @@ T Runner::getMetadataHelper(std::string method_name, T default_val) {
 }
 
 template <typename T>
-int32_t Runner::logitsToToken(const exec_aten::Tensor& logits_tensor) {
+int32_t Runner::logitsToToken(const Tensor& logits_tensor) {
   T* logits = logits_tensor.mutable_data_ptr<T>();
 
   // Since the logits are for all tokens, get the last token probabilities
@@ -135,7 +151,7 @@ int32_t Runner::logitsToToken(const exec_aten::Tensor& logits_tensor) {
 
 // Given an input token. Set up the inputs for the model and execute a single
 // step. Returning the logits tensor.
-Result<exec_aten::Tensor> Runner::run_model_step(
+Result<Tensor> Runner::run_model_step(
     int64_t input_token,
     TensorPtr& token,
     TensorPtr& start_pos,
@@ -167,7 +183,7 @@ Result<exec_aten::Tensor> Runner::run_model_step(
     char* new_inp_addr = io_mem_mgr_.update_k_caches_read(j, el_size);
     // inputs
     ET_CHECK_MSG(
-        internal::set_tensor_data(
+        set_tensor_data(
            *kv_tensors[j], new_inp_addr, kv_tensors[j]->nbytes()) == Error::Ok,
         "Failed to set input tensor when updating k_cache");
   }
@@ -177,13 +193,13 @@ Result<exec_aten::Tensor> Runner::run_model_step(
     char* new_inp_addr = io_mem_mgr_.update_v_caches_read(v_idx, v_offset);
 
     ET_CHECK_MSG(
-        internal::set_tensor_data(
+        set_tensor_data(
           *kv_tensors[j], new_inp_addr, kv_tensors[j]->nbytes()) == Error::Ok,
         "Failed to set input tensor when updating v_cache");
     // outputs
     char* new_out_addr = io_mem_mgr_.update_v_caches_write(v_idx, v_offset);
     ET_CHECK_MSG(
-        internal::set_tensor_data(
+        set_tensor_data(
          *kv_outputs[j], new_out_addr, kv_outputs[j]->nbytes()) == Error::Ok,
         "Failed to set output tensor when updating v_cache");
     ET_CHECK_MSG(
@@ -210,7 +226,7 @@ Error Runner::generate(
 
   // First token time only measures the time it takes to encode the prompt and
   // return a response token.
-  stats_.inference_start_ms = util::time_in_ms();
+  stats_.inference_start_ms = time_in_ms();
   shouldStop_ = false;
 
   // Set the sequence length to the max seq length if not provided
@@ -235,21 +251,21 @@ Error Runner::generate(
       "Sequence length exceeded - please increase the seq_len value passed to generate()");
 
   int32_t pos = 0, prev_token, cur_token = prompt_tokens[0];
-  std::vector<exec_aten::SizesType> token_shape = {1, 1};
+  std::vector<SizesType> token_shape = {1, 1};
 
   io_mem_mgr_.get_input_token_ptr()[0] = 0;
-  std::vector<exec_aten::SizesType> start_pos_shape = {1, 1};
+  std::vector<SizesType> start_pos_shape = {1, 1};
 
   float* atten_mask_ptr =
       reinterpret_cast<float*>(io_mem_mgr_.get_atten_mask_ptr());
   std::fill(atten_mask_ptr, atten_mask_ptr + max_seq_len_, -255);
   atten_mask_ptr[max_seq_len_ - 1] = 0;
 
-  std::vector<exec_aten::SizesType> atten_mask_shape = {1, max_seq_len_};
+  std::vector<SizesType> atten_mask_shape = {1, max_seq_len_};
 
-  std::vector<exec_aten::SizesType> logits_data_shape = {1, vocab_size_};
+  std::vector<SizesType> logits_data_shape = {1, vocab_size_};
 
-  std::vector<exec_aten::SizesType> hidden_states_data_shape = {1, 1, dim_};
+  std::vector<SizesType> hidden_states_data_shape = {1, 1, dim_};
 
   // initialize tensor wrappers
   auto token = from_blob(
@@ -274,7 +290,7 @@ Error Runner::generate(
         method_meta->input_tensor_meta(input_index);
 
     auto tensor_shape = tensor_meta->sizes();
-    std::vector<exec_aten::SizesType> sizes(
+    std::vector<SizesType> sizes(
         tensor_shape.data(), tensor_shape.data() + tensor_shape.size());
     kv_tensors.emplace_back(from_blob(
         io_mem_mgr_.get_k_caches_read_ptr(i),
@@ -284,7 +300,7 @@ Error Runner::generate(
     // outpus
     Result<TensorInfo> out_tensor_meta = method_meta->output_tensor_meta(i + 1);
     tensor_shape = out_tensor_meta->sizes();
-    sizes = std::vector<exec_aten::SizesType>{
+    sizes = std::vector<SizesType>{
         tensor_shape.data(), tensor_shape.data() + tensor_shape.size()};
     kv_outputs.emplace_back(from_blob(
         io_mem_mgr_.get_k_caches_write_ptr(i),
@@ -303,7 +319,7 @@ Error Runner::generate(
     Result<TensorInfo> tensor_meta =
         method_meta->input_tensor_meta(input_index);
     auto tensor_shape = tensor_meta->sizes();
-    std::vector<exec_aten::SizesType> sizes(
+    std::vector<SizesType> sizes(
         tensor_shape.data(), tensor_shape.data() + tensor_shape.size());
 
     kv_tensors.emplace_back(from_blob(
@@ -315,7 +331,7 @@ Error Runner::generate(
     Result<TensorInfo> out_tensor_meta =
         method_meta->output_tensor_meta(output_index);
     tensor_shape = out_tensor_meta->sizes();
-    sizes = std::vector<exec_aten::SizesType>{
+    sizes = std::vector<SizesType>{
         tensor_shape.data(), tensor_shape.data() + tensor_shape.size()};
 
     kv_outputs.push_back(from_blob(
@@ -342,19 +358,18 @@ Error Runner::generate(
     auto logits_res = run_model_step(
         cur_token, token, start_pos, atten_mask, kv_tensors, kv_outputs);
     if (pos == num_prompt_tokens) {
-      stats_.first_token_ms = util::time_in_ms();
+      stats_.first_token_ms = time_in_ms();
     } else if (pos == num_prompt_tokens - 1) {
-      stats_.prompt_eval_end_ms = util::time_in_ms();
+      stats_.prompt_eval_end_ms = time_in_ms();
     }
 
     ET_CHECK_OK_OR_RETURN_ERROR(logits_res.error());
-    exec_aten::Tensor& logits_tensor = logits_res.get();
+    Tensor& logits_tensor = logits_res.get();
     prev_token = cur_token;
-    long sample_start_time_ms = util::time_in_ms();
+    long sample_start_time_ms = time_in_ms();
 
     cur_token = logitsToToken<float>(logits_tensor);
-    stats_.aggregate_sampling_time_ms +=
-        util::time_in_ms() - sample_start_time_ms;
+    stats_.aggregate_sampling_time_ms += time_in_ms() - sample_start_time_ms;
 
     // advance the state machine
     if (pos < num_prompt_tokens - 1) {
@@ -381,7 +396,7 @@ Error Runner::generate(
       break;
     }
   }
-  stats_.inference_end_ms = util::time_in_ms();
+  stats_.inference_end_ms = time_in_ms();
 
   if (pos == seq_len) {
     ET_LOG(Info, "Sequence length (%i tokens) reached!", seq_len);
@@ -650,5 +665,4 @@ template bool Runner::getMetadataHelper<bool>(
     std::string method_name,
     bool default_val);
 
-} // namespace executor
-} // namespace torch
+} // namespace example
