|
12 | 12 |
|
13 | 13 | #include <executorch/runtime/core/exec_aten/exec_aten.h>
|
14 | 14 | #include <executorch/runtime/executor/method.h>
|
| 15 | +#include <executorch/runtime/executor/method_meta.h> |
15 | 16 | #include <executorch/runtime/platform/log.h>
|
16 | 17 | #ifdef USE_ATEN_LIB
|
17 | 18 | #include <ATen/ATen.h> // @manual=//caffe2/aten:ATen-core
|
@@ -61,51 +62,77 @@ inline void FillOnes(Tensor tensor) {
|
61 | 62 | * @returns An array of pointers that must be passed to `FreeInputs()` after
|
62 | 63 | * the Method is no longer needed.
|
63 | 64 | */
|
64 |
| -inline exec_aten::ArrayRef<void*> PrepareInputTensors(const Method& method) { |
| 65 | +inline exec_aten::ArrayRef<void*> PrepareInputTensors(Method& method) { |
| 66 | + auto method_meta = method.method_meta(); |
65 | 67 | size_t input_size = method.inputs_size();
|
66 | 68 | size_t num_allocated = 0;
|
67 | 69 | void** inputs = (void**)malloc(input_size * sizeof(void*));
|
68 |
| -#ifdef USE_ATEN_LIB |
69 |
| - auto deleteByNone = [](void* p) {}; |
| 70 | + |
70 | 71 | for (size_t i = 0; i < input_size; i++) {
|
71 |
| - if (!method.get_input(i).isTensor()) { |
| 72 | + if (*method_meta.input_tag(i) != Tag::Tensor) { |
72 | 73 | ET_LOG(Info, "input %zu is not a tensor, skipping", i);
|
73 | 74 | continue;
|
74 | 75 | }
|
75 |
| - const auto& t = method.get_input(i).toTensor(); |
76 |
| - at::StorageImpl* storage = |
77 |
| - t.unsafeGetTensorImpl()->unsafe_storage().unsafeGetStorageImpl(); |
78 |
| - if (storage->data_ptr().get() == nullptr) { |
79 |
| - ET_LOG(Info, "input not initialized."); |
80 |
| - inputs[num_allocated++] = malloc(t.nbytes()); |
81 |
| - storage->set_data_ptr(at::DataPtr( |
82 |
| - inputs[num_allocated - 1], |
83 |
| - inputs[num_allocated - 1], |
84 |
| - deleteByNone, |
85 |
| - DeviceType::CPU)); |
86 |
| - storage->set_nbytes(t.nbytes()); |
87 |
| - } else { |
88 |
| - ET_LOG(Info, "input already initialized, refilling."); |
| 76 | + |
| 77 | + // Tensor Input. Grab meta data and allocate buffer |
| 78 | + auto tensor_meta = method_meta.input_tensor_meta(i); |
| 79 | + inputs[num_allocated++] = malloc(tensor_meta->nbytes()); |
| 80 | + |
| 81 | +#ifdef USE_ATEN_LIB |
| 82 | + std::vector<int64_t> at_tensor_sizes; |
| 83 | + for (auto s : tensor_meta->sizes()) { |
| 84 | + at_tensor_sizes.push_back(s); |
89 | 85 | }
|
| 86 | + at::Tensor t = at::from_blob( |
| 87 | + inputs[num_allocated - 1], |
| 88 | + at_tensor_sizes, |
| 89 | + at::TensorOptions(tensor_meta->scalar_type())); |
90 | 90 | t.fill_(1.0f);
|
91 |
| - } |
92 |
| -#else |
93 |
| - for (size_t i = 0; i < input_size; i++) { |
94 |
| - if (!method.get_input(i).isTensor()) { |
95 |
| - ET_LOG(Info, "input %zu is not a tensor, skipping", i); |
96 |
| - continue; |
| 91 | + |
| 92 | +#else // Portable Tensor |
| 93 | + // The only memory that needs to persist after set_input is called is the |
| 94 | + // data ptr of the input tensor, and that is only if the Method did not |
| 95 | + // memory plan buffer space for the inputs and instead is expecting the user |
| 96 | + // to provide them. Meta data like sizes and dim order are used to ensure |
| 97 | + // the input aligns with the values expected by the plan, but references to |
| 98 | + // them are not held onto. |
| 99 | + |
| 100 | + TensorImpl::SizesType* sizes = static_cast<TensorImpl::SizesType*>( |
| 101 | + malloc(sizeof(TensorImpl::SizesType) * tensor_meta->sizes().size())); |
| 102 | + TensorImpl::DimOrderType* dim_order = |
| 103 | + static_cast<TensorImpl::DimOrderType*>(malloc( |
| 104 | + sizeof(TensorImpl::DimOrderType) * |
| 105 | + tensor_meta->dim_order().size())); |
| 106 | + |
| 107 | + for (size_t size_idx = 0; size_idx < tensor_meta->sizes().size(); |
| 108 | + size_idx++) { |
| 109 | + sizes[size_idx] = tensor_meta->sizes()[size_idx]; |
97 | 110 | }
|
98 |
| - const auto& t = method.get_input(i).toTensor(); |
99 |
| - if (t.const_data_ptr() == nullptr) { |
100 |
| - ET_LOG(Info, "input not initialized."); |
101 |
| - inputs[num_allocated++] = malloc(t.nbytes()); |
102 |
| - t.set_data(inputs[num_allocated - 1]); |
103 |
| - } else { |
104 |
| - ET_LOG(Info, "input already initialized, refilling."); |
| 111 | + for (size_t dim_idx = 0; dim_idx < tensor_meta->dim_order().size(); |
| 112 | + dim_idx++) { |
| 113 | + dim_order[dim_idx] = tensor_meta->dim_order()[dim_idx]; |
105 | 114 | }
|
| 115 | + |
| 116 | + TensorImpl impl = TensorImpl( |
| 117 | + tensor_meta->scalar_type(), |
| 118 | + tensor_meta->sizes().size(), |
| 119 | + sizes, |
| 120 | + inputs[num_allocated - 1], |
| 121 | + dim_order); |
| 122 | + Tensor t(&impl); |
106 | 123 | FillOnes(t);
|
107 |
| - } |
108 | 124 | #endif
|
| 125 | + auto error = method.set_input(t, i); |
| 126 | + ET_CHECK_MSG( |
| 127 | + error == Error::Ok, |
| 128 | + "Error: 0x%" PRIx32 " setting input %zu.", |
| 129 | + error, |
| 130 | + i); |
| 131 | +#ifndef USE_ATEN_LIB // Portable Tensor |
| 132 | + free(sizes); |
| 133 | + free(dim_order); |
| 134 | +#endif |
| 135 | + } |
109 | 136 | return {inputs, num_allocated};
|
110 | 137 | }
|
111 | 138 |
|
|
0 commit comments