Skip to content

Commit eb31945

Browse files
authored
Update run.cpp (#779)
surgery to simplify calls to ET/AOTI for pedagogical purposes
1 parent 5aee793 commit eb31945

File tree

1 file changed

+3
-12
lines changed
  • parking_lot/unsupported/runner

1 file changed

+3
-12
lines changed

parking_lot/unsupported/runner/run.cpp

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -151,31 +151,22 @@ float* forward(Transformer* transformer, int token, int pos) {
151151
torch::Tensor token_tensor = torch::from_blob(token_buffer, {1, 1}, torch::kLong);
152152
torch::Tensor pos_tensor = torch::from_blob(pos_buffer, {1}, torch::kLong);
153153
std::vector<torch::Tensor> inputs{token_tensor, pos_tensor};
154-
154+
// call AOTI model
155155
torch::Tensor result = transformer->runner->run(inputs)[0];
156156
auto logits = result[0].data_ptr();
157-
158157
#else // __ET_MODEL__
159158
ManagedTensor pos_managed(
160159
pos_buffer, sizeof(int64_t), { 1 }, ScalarType::Long);
161-
#ifndef __KV_CACHE__
162-
// @lint-ignore CLANGTIDY facebook-hte-LocalUncheckedArrayBounds
163-
ManagedTensor tokens_managed(&(s->toks[pos]), /*ignored*/sizeof(int64_t)*(pos+1), {1, 1}, ScalarType::Long);
164-
#else // __KV_CACHE__
165160
ManagedTensor tokens_managed(
166161
token_buffer, sizeof(int64_t), {1, 1}, ScalarType::Long);
167-
#endif
168162
std::vector<EValue> inputs;
169163
auto tmp1 = EValue(tokens_managed.get_aliasing_tensor());
170164
auto tmp2 = EValue(pos_managed.get_aliasing_tensor());
171-
172165
inputs.push_back(tmp1);
173166
inputs.push_back(tmp2);
167+
// call ET model
174168
Result<std::vector<EValue>> outputs_res = transformer->runner->forward(inputs);
175-
if (!outputs_res.ok()) {
176-
fprintf(stderr, "Executorch forward() failed.");
177-
exit(EXIT_FAILURE);
178-
}
169+
assert(outputs_res.ok());
179170
std::vector<EValue> result = outputs_res.get();
180171
auto logits = result[0].toTensor().const_data_ptr();
181172
#endif

0 commit comments

Comments
 (0)