@@ -21,15 +21,14 @@ namespace extension {
21
21
22
22
#ifndef USE_ATEN_LIB
23
23
/* *
24
- * A smart pointer type for managing the lifecycle of a TensorImpl.
24
+ * A smart pointer for managing the lifecycle of a TensorImpl.
25
25
*
26
- * TensorImplPtr uses a shared pointer because multiple Tensor objects might
27
- * share the same underlying data and metadata. This shared ownership model
28
- * ensures that the TensorImpl is only destroyed when all references to it are
29
- * gone, providing a safe and efficient way to manage shared tensor
30
- * implementations. This abstraction is designed to be a safer and more
31
- * convenient alternative to the original TensorImpl, which does not
32
- * manage metadata by design.
26
+ * TensorImplPtr uses a shared pointer since multiple Tensor objects may
27
+ * share the same underlying data and metadata. This shared ownership ensures
28
+ * that the TensorImpl is destroyed only when all references to it are gone,
29
+ * providing a safe and efficient way to manage shared tensor implementations.
30
+ * It serves as a safer, more convenient alternative to the original TensorImpl,
31
+ * which does not manage its metadata by design.
33
32
*/
34
33
using TensorImplPtr = std::shared_ptr<exec_aten::TensorImpl>;
35
34
#else
@@ -48,23 +47,23 @@ using TensorImplPtr =
48
47
* Creates a TensorImplPtr that manages a newly created TensorImpl with the
49
48
* specified properties.
50
49
*
51
- * @param type The scalar type of the tensor elements.
52
50
* @param sizes A vector specifying the size of each dimension.
53
51
* @param data A pointer to the data buffer.
54
52
* @param dim_order A vector specifying the order of dimensions.
55
53
* @param strides A vector specifying the strides of each dimension.
54
+ * @param type The scalar type of the tensor elements.
56
55
* @param dynamism Specifies the mutability of the tensor's shape.
57
56
* @param deleter A custom deleter function for managing the lifetime of the
58
- * data buffer. If provided, this deleter will be called when the managed
59
- * TensorImpl object is destroyed.
57
+ * data buffer. If provided, this deleter is called when the managed TensorImpl
58
+ * is destroyed.
60
59
* @return A TensorImplPtr managing the newly created TensorImpl.
61
60
*/
62
61
TensorImplPtr make_tensor_impl_ptr (
63
- exec_aten::ScalarType type,
64
62
std::vector<exec_aten::SizesType> sizes,
65
63
void * data,
66
- std::vector<exec_aten::DimOrderType> dim_order = {},
67
- std::vector<exec_aten::StridesType> strides = {},
64
+ std::vector<exec_aten::DimOrderType> dim_order,
65
+ std::vector<exec_aten::StridesType> strides,
66
+ exec_aten::ScalarType type = exec_aten::ScalarType::Float,
68
67
exec_aten::TensorShapeDynamism dynamism =
69
68
exec_aten::TensorShapeDynamism::DYNAMIC_BOUND,
70
69
std::function<void (void *)> deleter = nullptr);
@@ -73,37 +72,64 @@ TensorImplPtr make_tensor_impl_ptr(
73
72
* Creates a TensorImplPtr that manages a newly created TensorImpl with the
74
73
* specified properties.
75
74
*
76
- * This template overload is specialized for cases where the tensor data is
77
- * provided as a vector. The scalar type is automatically deduced from the
78
- * vector's data type. The deleter ensures that the data vector is properly
79
- * managed and its lifetime is tied to the TensorImpl.
75
+ * @param sizes A vector specifying the size of each dimension.
76
+ * @param data A pointer to the data buffer.
77
+ * @param type The scalar type of the tensor elements.
78
+ * @param dynamism Specifies the mutability of the tensor's shape.
79
+ * @param deleter A custom deleter function for managing the lifetime of the
80
+ * data buffer. If provided, this deleter is called when the managed TensorImpl
81
+ * is destroyed.
82
+ * @return A TensorImplPtr managing the newly created TensorImpl.
83
+ */
84
+ inline TensorImplPtr make_tensor_impl_ptr (
85
+ std::vector<exec_aten::SizesType> sizes,
86
+ void * data,
87
+ exec_aten::ScalarType type = exec_aten::ScalarType::Float,
88
+ exec_aten::TensorShapeDynamism dynamism =
89
+ exec_aten::TensorShapeDynamism::DYNAMIC_BOUND,
90
+ std::function<void (void *)> deleter = nullptr) {
91
+ return make_tensor_impl_ptr (
92
+ std::move (sizes), data, {}, {}, type, dynamism, std::move (deleter));
93
+ }
94
+
95
+ /* *
96
+ * Creates a TensorImplPtr that manages a newly created TensorImpl with the
97
+ * specified properties.
98
+ *
99
+ * This template overload is specialized for cases where tensor data is provided
100
+ * as a vector. The scalar type is automatically deduced from the vector's data
101
+ * type. The deleter ensures that the data vector is properly managed, with its
102
+ * lifetime tied to the TensorImpl.
80
103
*
81
104
* @tparam T The C++ type of the tensor elements, deduced from the vector.
82
105
* @param sizes A vector specifying the size of each dimension.
83
106
* @param data A vector containing the tensor's data.
84
107
* @param dim_order A vector specifying the order of dimensions.
85
108
* @param strides A vector specifying the strides of each dimension.
109
+ * @param type The scalar type of the tensor elements.
86
110
* @param dynamism Specifies the mutability of the tensor's shape.
87
111
* @return A TensorImplPtr that manages the newly created TensorImpl.
88
112
*/
89
- template <typename T = float >
90
- TensorImplPtr make_tensor_impl_ptr (
113
+ template <
114
+ typename T = float ,
115
+ exec_aten::ScalarType deduced_type = runtime::CppTypeToScalarType<T>::value>
116
+ inline TensorImplPtr make_tensor_impl_ptr (
91
117
std::vector<exec_aten::SizesType> sizes,
92
118
std::vector<T> data,
93
119
std::vector<exec_aten::DimOrderType> dim_order = {},
94
120
std::vector<exec_aten::StridesType> strides = {},
121
+ exec_aten::ScalarType type = deduced_type,
95
122
exec_aten::TensorShapeDynamism dynamism =
96
123
exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) {
97
- constexpr exec_aten::ScalarType scalar_type =
98
- runtime::CppTypeToScalarType<T>::value;
124
+ ET_CHECK_MSG (type == deduced_type, " Type does not match the deduced type." );
99
125
const auto raw_data_ptr = data.data ();
100
126
auto data_ptr = std::make_shared<std::vector<T>>(std::move (data));
101
127
return make_tensor_impl_ptr (
102
- scalar_type,
103
128
std::move (sizes),
104
129
raw_data_ptr,
105
130
std::move (dim_order),
106
131
std::move (strides),
132
+ type,
107
133
dynamism,
108
134
[data_ptr = std::move (data_ptr)](void *) {});
109
135
}
@@ -119,53 +145,161 @@ TensorImplPtr make_tensor_impl_ptr(
119
145
*
120
146
* @tparam T The C++ type of the tensor elements, deduced from the vector.
121
147
* @param data A vector containing the tensor's data.
148
+ * @param type The scalar type of the tensor elements.
122
149
* @param dynamism Specifies the mutability of the tensor's shape.
123
150
* @return A TensorImplPtr that manages the newly created TensorImpl.
124
151
*/
125
- template <typename T = float >
126
- TensorImplPtr make_tensor_impl_ptr (
152
+ template <
153
+ typename T = float ,
154
+ exec_aten::ScalarType deduced_type = runtime::CppTypeToScalarType<T>::value>
155
+ inline TensorImplPtr make_tensor_impl_ptr (
127
156
std::vector<T> data,
157
+ exec_aten::ScalarType type = deduced_type,
128
158
exec_aten::TensorShapeDynamism dynamism =
129
159
exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) {
130
- constexpr exec_aten::ScalarType scalar_type =
131
- runtime::CppTypeToScalarType<T>::value;
160
+ ET_CHECK_MSG (type == deduced_type, " Type does not match the deduced type." );
132
161
std::vector<exec_aten::SizesType> sizes{exec_aten::SizesType (data.size ())};
133
162
const auto raw_data_ptr = data.data ();
134
163
auto data_ptr = std::make_shared<std::vector<T>>(std::move (data));
135
164
return make_tensor_impl_ptr (
136
- scalar_type,
165
+ std::move (sizes), std::move (data), {0 }, {1 }, type, dynamism);
166
+ }
167
+
168
+ /* *
169
+ * Creates a TensorImplPtr that manages a newly created TensorImpl with the
170
+ * specified properties.
171
+ *
172
+ * This template overload is specialized for cases where tensor data is provided
173
+ * as an initializer list. The scalar type is automatically deduced from the
174
+ * initializer list's data type. The deleter ensures that the data is properly
175
+ * managed, with its lifetime tied to the TensorImpl.
176
+ *
177
+ * @tparam T The C++ type of the tensor elements, deduced from the initializer
178
+ * list.
179
+ * @param sizes A vector specifying the size of each dimension.
180
+ * @param list An initializer list containing the tensor's data.
181
+ * @param dim_order A vector specifying the order of dimensions.
182
+ * @param strides A vector specifying the strides of each dimension.
183
+ * @param type The scalar type of the tensor elements.
184
+ * @param dynamism Specifies the mutability of the tensor's shape.
185
+ * @return A TensorImplPtr that manages the newly created TensorImpl.
186
+ */
187
+ template <
188
+ typename T = float ,
189
+ exec_aten::ScalarType deduced_type = runtime::CppTypeToScalarType<T>::value>
190
+ inline TensorImplPtr make_tensor_impl_ptr (
191
+ std::vector<exec_aten::SizesType> sizes,
192
+ std::initializer_list<T> list,
193
+ std::vector<exec_aten::DimOrderType> dim_order = {},
194
+ std::vector<exec_aten::StridesType> strides = {},
195
+ exec_aten::ScalarType type = deduced_type,
196
+ exec_aten::TensorShapeDynamism dynamism =
197
+ exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) {
198
+ ET_CHECK_MSG (type == deduced_type, " Type does not match the deduced type." );
199
+ auto data = std::vector<T>(std::move (list));
200
+ const auto raw_data_ptr = data.data ();
201
+ auto data_ptr = std::make_shared<std::vector<T>>(std::move (data));
202
+ return make_tensor_impl_ptr (
137
203
std::move (sizes),
138
204
raw_data_ptr,
139
- {0 },
140
- {1 },
205
+ std::move (dim_order),
206
+ std::move (strides),
207
+ type,
141
208
dynamism,
142
209
[data_ptr = std::move (data_ptr)](void *) {});
143
210
}
144
211
212
+ /* *
213
+ * Creates a TensorImplPtr that manages a newly created TensorImpl with the
214
+ * specified properties.
215
+ *
216
+ * This template overload is specialized for cases where the tensor data is
217
+ * provided as an initializer list. The scalar type is automatically deduced
218
+ * from the initializer list's data type. The deleter ensures that the data is
219
+ * properly managed and its lifetime is tied to the TensorImpl.
220
+ *
221
+ * @tparam T The C++ type of the tensor elements, deduced from the initializer
222
+ * list.
223
+ * @param sizes A vector specifying the size of each dimension.
224
+ * @param list An initializer list containing the tensor's data.
225
+ * @param type The scalar type of the tensor elements.
226
+ * @param dynamism Specifies the mutability of the tensor's shape.
227
+ * @return A TensorImplPtr that manages the newly created TensorImpl.
228
+ */
229
+ template <
230
+ typename T = float ,
231
+ exec_aten::ScalarType deduced_type = runtime::CppTypeToScalarType<T>::value>
232
+ inline TensorImplPtr make_tensor_impl_ptr (
233
+ std::initializer_list<T> list,
234
+ exec_aten::ScalarType type = deduced_type,
235
+ exec_aten::TensorShapeDynamism dynamism =
236
+ exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) {
237
+ ET_CHECK_MSG (type == deduced_type, " Type does not match the deduced type." );
238
+ std::vector<exec_aten::SizesType> sizes{exec_aten::SizesType (list.size ())};
239
+ return make_tensor_impl_ptr (
240
+ std::move (sizes), std::move (list), {0 }, {1 }, type, dynamism);
241
+ }
242
+
243
+ /* *
244
+ * Creates a TensorImplPtr to manage a Tensor with a single scalar value.
245
+ *
246
+ * @tparam T The C++ type of the scalar value.
247
+ * @param value The scalar value used for the Tensor.
248
+ * @return A TensorImplPtr managing the newly created TensorImpl.
249
+ */
250
+ template <typename T>
251
+ inline TensorImplPtr make_tensor_impl_ptr (T value) {
252
+ return make_tensor_impl_ptr ({}, std::vector<T>{value});
253
+ }
254
+
145
255
/* *
146
256
* Creates a TensorImplPtr that manages a newly created TensorImpl with the
147
257
* specified properties.
148
258
*
149
259
* This overload accepts a raw memory buffer stored in a std::vector<uint8_t>
150
- * and a scalar type to interpret the data. The vector is managed, and the
151
- * memory's lifetime is tied to the TensorImpl.
260
+ * and a scalar type to interpret the data. The vector is managed, and its
261
+ * lifetime is tied to the TensorImpl.
152
262
*
153
- * @param scalar_type The scalar type of the tensor elements.
154
263
* @param sizes A vector specifying the size of each dimension.
155
- * @param data A vector containing the raw memory for the tensor's data.
264
+ * @param data A vector containing the raw memory buffer for the tensor's data.
156
265
* @param dim_order A vector specifying the order of dimensions.
157
266
* @param strides A vector specifying the strides of each dimension.
267
+ * @param type The scalar type of the tensor elements.
158
268
* @param dynamism Specifies the mutability of the tensor's shape.
159
269
* @return A TensorImplPtr managing the newly created TensorImpl.
160
270
*/
161
271
TensorImplPtr make_tensor_impl_ptr (
162
- exec_aten::ScalarType scalar_type,
163
272
std::vector<exec_aten::SizesType> sizes,
164
273
std::vector<uint8_t > data,
165
- std::vector<exec_aten::DimOrderType> dim_order = {},
166
- std::vector<exec_aten::StridesType> strides = {},
274
+ std::vector<exec_aten::DimOrderType> dim_order,
275
+ std::vector<exec_aten::StridesType> strides,
276
+ exec_aten::ScalarType type = exec_aten::ScalarType::Float,
167
277
exec_aten::TensorShapeDynamism dynamism =
168
278
exec_aten::TensorShapeDynamism::DYNAMIC_BOUND);
169
279
280
+ /* *
281
+ * Creates a TensorImplPtr that manages a newly created TensorImpl with the
282
+ * specified properties.
283
+ *
284
+ * This overload accepts a raw memory buffer stored in a std::vector<uint8_t>
285
+ * and a scalar type to interpret the data. The vector is managed, and the
286
+ * memory's lifetime is tied to the TensorImpl.
287
+ *
288
+ * @param sizes A vector specifying the size of each dimension.
289
+ * @param data A vector containing the raw memory for the tensor's data.
290
+ * @param type The scalar type of the tensor elements.
291
+ * @param dynamism Specifies the mutability of the tensor's shape.
292
+ * @return A TensorImplPtr managing the newly created TensorImpl.
293
+ */
294
+ inline TensorImplPtr make_tensor_impl_ptr (
295
+ std::vector<exec_aten::SizesType> sizes,
296
+ std::vector<uint8_t > data,
297
+ exec_aten::ScalarType type = exec_aten::ScalarType::Float,
298
+ exec_aten::TensorShapeDynamism dynamism =
299
+ exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) {
300
+ return make_tensor_impl_ptr (
301
+ std::move (sizes), std::move (data), {}, {}, type, dynamism);
302
+ }
303
+
170
304
} // namespace extension
171
305
} // namespace executorch
0 commit comments