Skip to content

Commit 48c2f9d

Browse files
committed
Update ManagedTensor to reflect the new API
1 parent 67b1d31 commit 48c2f9d

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

runner/run.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -203,9 +203,9 @@ float* forward(Transformer* transformer, int token, int pos) {
203203
.to(torch::kCPU);
204204
auto logits = result[0].data_ptr();
205205
#else // __ET_MODEL__
206-
ManagedTensor pos_managed(pos_buffer, sizeof(int64_t), {1}, ScalarType::Long);
206+
ManagedTensor pos_managed(pos_buffer, {1}, ScalarType::Long);
207207
ManagedTensor tokens_managed(
208-
token_buffer, sizeof(int64_t), {1, 1}, ScalarType::Long);
208+
token_buffer, {1, 1}, ScalarType::Long);
209209
std::vector<EValue> inputs;
210210
auto tmp1 = EValue(tokens_managed.get_aliasing_tensor());
211211
auto tmp2 = EValue(pos_managed.get_aliasing_tensor());

0 commit comments

Comments
 (0)