@@ -151,31 +151,22 @@ float* forward(Transformer* transformer, int token, int pos) {
151
151
torch::Tensor token_tensor = torch::from_blob (token_buffer, {1 , 1 }, torch::kLong );
152
152
torch::Tensor pos_tensor = torch::from_blob (pos_buffer, {1 }, torch::kLong );
153
153
std::vector<torch::Tensor> inputs{token_tensor, pos_tensor};
154
-
154
+ // call AOTI model
155
155
torch::Tensor result = transformer->runner ->run (inputs)[0 ];
156
156
auto logits = result[0 ].data_ptr ();
157
-
158
157
#else // __ET_MODEL__
159
158
ManagedTensor pos_managed (
160
159
pos_buffer, sizeof (int64_t ), { 1 }, ScalarType::Long);
161
- #ifndef __KV_CACHE__
162
- // @lint-ignore CLANGTIDY facebook-hte-LocalUncheckedArrayBounds
163
- ManagedTensor tokens_managed (&(s->toks [pos]), /* ignored*/ sizeof (int64_t )*(pos+1 ), {1 , 1 }, ScalarType::Long);
164
- #else // __KV_CACHE__
165
160
ManagedTensor tokens_managed (
166
161
token_buffer, sizeof (int64_t ), {1 , 1 }, ScalarType::Long);
167
- #endif
168
162
std::vector<EValue> inputs;
169
163
auto tmp1 = EValue (tokens_managed.get_aliasing_tensor ());
170
164
auto tmp2 = EValue (pos_managed.get_aliasing_tensor ());
171
-
172
165
inputs.push_back (tmp1);
173
166
inputs.push_back (tmp2);
167
+ // call ET model
174
168
Result<std::vector<EValue>> outputs_res = transformer->runner ->forward (inputs);
175
- if (!outputs_res.ok ()) {
176
- fprintf (stderr, " Executorch forward() failed." );
177
- exit (EXIT_FAILURE);
178
- }
169
+ assert (outputs_res.ok ());
179
170
std::vector<EValue> result = outputs_res.get ();
180
171
auto logits = result[0 ].toTensor ().const_data_ptr ();
181
172
#endif
0 commit comments