@@ -21,15 +21,14 @@ namespace extension {
21
21
22
22
#ifndef USE_ATEN_LIB
23
23
/* *
24
- * A smart pointer type for managing the lifecycle of a TensorImpl.
24
+ * A smart pointer for managing the lifecycle of a TensorImpl.
25
25
*
26
- * TensorImplPtr uses a shared pointer because multiple Tensor objects might
27
- * share the same underlying data and metadata. This shared ownership model
28
- * ensures that the TensorImpl is only destroyed when all references to it are
29
- * gone, providing a safe and efficient way to manage shared tensor
30
- * implementations. This abstraction is designed to be a safer and more
31
- * convenient alternative to the original TensorImpl, which does not
32
- * manage metadata by design.
26
+ * TensorImplPtr uses a shared pointer since multiple Tensor objects may
27
+ * share the same underlying data and metadata. This shared ownership ensures
28
+ * that the TensorImpl is destroyed only when all references to it are gone,
29
+ * providing a safe and efficient way to manage shared tensor implementations.
30
+ * It serves as a safer, more convenient alternative to the original TensorImpl,
31
+ * which does not manage its metadata by design.
33
32
*/
34
33
using TensorImplPtr = std::shared_ptr<exec_aten::TensorImpl>;
35
34
#else
@@ -48,23 +47,23 @@ using TensorImplPtr =
48
47
* Creates a TensorImplPtr that manages a newly created TensorImpl with the
49
48
* specified properties.
50
49
*
51
- * @param type The scalar type of the tensor elements.
52
50
* @param sizes A vector specifying the size of each dimension.
53
51
* @param data A pointer to the data buffer.
54
52
* @param dim_order A vector specifying the order of dimensions.
55
53
* @param strides A vector specifying the strides of each dimension.
54
+ * @param type The scalar type of the tensor elements.
56
55
* @param dynamism Specifies the mutability of the tensor's shape.
57
56
* @param deleter A custom deleter function for managing the lifetime of the
58
- * data buffer. If provided, this deleter will be called when the managed
59
- * TensorImpl object is destroyed.
57
+ * data buffer. If provided, this deleter is called when the managed TensorImpl
58
+ * is destroyed.
60
59
* @return A TensorImplPtr managing the newly created TensorImpl.
61
60
*/
62
61
TensorImplPtr make_tensor_impl_ptr (
63
- exec_aten::ScalarType type,
64
62
std::vector<exec_aten::SizesType> sizes,
65
63
void * data,
66
- std::vector<exec_aten::DimOrderType> dim_order = {},
67
- std::vector<exec_aten::StridesType> strides = {},
64
+ std::vector<exec_aten::DimOrderType> dim_order,
65
+ std::vector<exec_aten::StridesType> strides,
66
+ exec_aten::ScalarType type = exec_aten::ScalarType::Float,
68
67
exec_aten::TensorShapeDynamism dynamism =
69
68
exec_aten::TensorShapeDynamism::DYNAMIC_BOUND,
70
69
std::function<void (void *)> deleter = nullptr);
@@ -73,37 +72,64 @@ TensorImplPtr make_tensor_impl_ptr(
73
72
* Creates a TensorImplPtr that manages a newly created TensorImpl with the
74
73
* specified properties.
75
74
*
76
- * This template overload is specialized for cases where the tensor data is
77
- * provided as a vector. The scalar type is automatically deduced from the
78
- * vector's data type. The deleter ensures that the data vector is properly
79
- * managed and its lifetime is tied to the TensorImpl.
75
+ * @param sizes A vector specifying the size of each dimension.
76
+ * @param data A pointer to the data buffer.
77
+ * @param type The scalar type of the tensor elements.
78
+ * @param dynamism Specifies the mutability of the tensor's shape.
79
+ * @param deleter A custom deleter function for managing the lifetime of the
80
+ * data buffer. If provided, this deleter is called when the managed TensorImpl
81
+ * is destroyed.
82
+ * @return A TensorImplPtr managing the newly created TensorImpl.
83
+ */
84
+ inline TensorImplPtr make_tensor_impl_ptr (
85
+ std::vector<exec_aten::SizesType> sizes,
86
+ void * data,
87
+ exec_aten::ScalarType type = exec_aten::ScalarType::Float,
88
+ exec_aten::TensorShapeDynamism dynamism =
89
+ exec_aten::TensorShapeDynamism::DYNAMIC_BOUND,
90
+ std::function<void (void *)> deleter = nullptr) {
91
+ return make_tensor_impl_ptr (
92
+ std::move (sizes), data, {}, {}, type, dynamism, std::move (deleter));
93
+ }
94
+
95
+ /* *
96
+ * Creates a TensorImplPtr that manages a newly created TensorImpl with the
97
+ * specified properties.
98
+ *
99
+ * This template overload is specialized for cases where tensor data is provided
100
+ * as a vector. The scalar type is automatically deduced from the vector's data
101
+ * type. The deleter ensures that the data vector is properly managed, with its
102
+ * lifetime tied to the TensorImpl.
80
103
*
81
104
* @tparam T The C++ type of the tensor elements, deduced from the vector.
82
105
* @param sizes A vector specifying the size of each dimension.
83
106
* @param data A vector containing the tensor's data.
84
107
* @param dim_order A vector specifying the order of dimensions.
85
108
* @param strides A vector specifying the strides of each dimension.
109
+ * @param type The scalar type of the tensor elements.
86
110
* @param dynamism Specifies the mutability of the tensor's shape.
87
111
* @return A TensorImplPtr that manages the newly created TensorImpl.
88
112
*/
89
- template <typename T = float >
113
+ template <
114
+ typename T = float ,
115
+ exec_aten::ScalarType deduced_type = runtime::CppTypeToScalarType<T>::value>
90
116
inline TensorImplPtr make_tensor_impl_ptr (
91
117
std::vector<exec_aten::SizesType> sizes,
92
118
std::vector<T> data,
93
119
std::vector<exec_aten::DimOrderType> dim_order = {},
94
120
std::vector<exec_aten::StridesType> strides = {},
121
+ exec_aten::ScalarType type = deduced_type,
95
122
exec_aten::TensorShapeDynamism dynamism =
96
123
exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) {
97
- constexpr exec_aten::ScalarType scalar_type =
98
- runtime::CppTypeToScalarType<T>::value;
124
+ ET_CHECK_MSG (type == deduced_type, " Type does not match the deduced type." );
99
125
const auto raw_data_ptr = data.data ();
100
126
auto data_ptr = std::make_shared<std::vector<T>>(std::move (data));
101
127
return make_tensor_impl_ptr (
102
- scalar_type,
103
128
std::move (sizes),
104
129
raw_data_ptr,
105
130
std::move (dim_order),
106
131
std::move (strides),
132
+ type,
107
133
dynamism,
108
134
[data_ptr = std::move (data_ptr)](void *) {});
109
135
}
@@ -119,43 +145,159 @@ inline TensorImplPtr make_tensor_impl_ptr(
119
145
*
120
146
* @tparam T The C++ type of the tensor elements, deduced from the vector.
121
147
* @param data A vector containing the tensor's data.
148
+ * @param type The scalar type of the tensor elements.
122
149
* @param dynamism Specifies the mutability of the tensor's shape.
123
150
* @return A TensorImplPtr that manages the newly created TensorImpl.
124
151
*/
125
- template <typename T = float >
152
+ template <
153
+ typename T = float ,
154
+ exec_aten::ScalarType deduced_type = runtime::CppTypeToScalarType<T>::value>
126
155
inline TensorImplPtr make_tensor_impl_ptr (
127
156
std::vector<T> data,
157
+ exec_aten::ScalarType type = deduced_type,
128
158
exec_aten::TensorShapeDynamism dynamism =
129
159
exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) {
160
+ ET_CHECK_MSG (type == deduced_type, " Type does not match the deduced type." );
130
161
std::vector<exec_aten::SizesType> sizes{exec_aten::SizesType (data.size ())};
131
162
return make_tensor_impl_ptr (
132
- std::move (sizes), std::move (data), {0 }, {1 }, dynamism);
163
+ std::move (sizes), std::move (data), {0 }, {1 }, type, dynamism);
164
+ }
165
+
166
+ /* *
167
+ * Creates a TensorImplPtr that manages a newly created TensorImpl with the
168
+ * specified properties.
169
+ *
170
+ * This template overload is specialized for cases where tensor data is provided
171
+ * as an initializer list. The scalar type is automatically deduced from the
172
+ * initializer list's data type. The deleter ensures that the data is properly
173
+ * managed, with its lifetime tied to the TensorImpl.
174
+ *
175
+ * @tparam T The C++ type of the tensor elements, deduced from the initializer
176
+ * list.
177
+ * @param sizes A vector specifying the size of each dimension.
178
+ * @param list An initializer list containing the tensor's data.
179
+ * @param dim_order A vector specifying the order of dimensions.
180
+ * @param strides A vector specifying the strides of each dimension.
181
+ * @param type The scalar type of the tensor elements.
182
+ * @param dynamism Specifies the mutability of the tensor's shape.
183
+ * @return A TensorImplPtr that manages the newly created TensorImpl.
184
+ */
185
+ template <
186
+ typename T = float ,
187
+ exec_aten::ScalarType deduced_type = runtime::CppTypeToScalarType<T>::value>
188
+ inline TensorImplPtr make_tensor_impl_ptr (
189
+ std::vector<exec_aten::SizesType> sizes,
190
+ std::initializer_list<T> list,
191
+ std::vector<exec_aten::DimOrderType> dim_order = {},
192
+ std::vector<exec_aten::StridesType> strides = {},
193
+ exec_aten::ScalarType type = deduced_type,
194
+ exec_aten::TensorShapeDynamism dynamism =
195
+ exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) {
196
+ ET_CHECK_MSG (type == deduced_type, " Type does not match the deduced type." );
197
+ auto data = std::vector<T>(std::move (list));
198
+ const auto raw_data_ptr = data.data ();
199
+ auto data_ptr = std::make_shared<std::vector<T>>(std::move (data));
200
+ return make_tensor_impl_ptr (
201
+ std::move (sizes),
202
+ raw_data_ptr,
203
+ std::move (dim_order),
204
+ std::move (strides),
205
+ type,
206
+ dynamism,
207
+ [data_ptr = std::move (data_ptr)](void *) {});
208
+ }
209
+
210
+ /* *
211
+ * Creates a TensorImplPtr that manages a newly created TensorImpl with the
212
+ * specified properties.
213
+ *
214
+ * This template overload is specialized for cases where the tensor data is
215
+ * provided as an initializer list. The scalar type is automatically deduced
216
+ * from the initializer list's data type. The deleter ensures that the data is
217
+ * properly managed and its lifetime is tied to the TensorImpl.
218
+ *
219
+ * @tparam T The C++ type of the tensor elements, deduced from the initializer
220
+ * list.
221
+ * @param sizes A vector specifying the size of each dimension.
222
+ * @param list An initializer list containing the tensor's data.
223
+ * @param type The scalar type of the tensor elements.
224
+ * @param dynamism Specifies the mutability of the tensor's shape.
225
+ * @return A TensorImplPtr that manages the newly created TensorImpl.
226
+ */
227
+ template <
228
+ typename T = float ,
229
+ exec_aten::ScalarType deduced_type = runtime::CppTypeToScalarType<T>::value>
230
+ inline TensorImplPtr make_tensor_impl_ptr (
231
+ std::initializer_list<T> list,
232
+ exec_aten::ScalarType type = deduced_type,
233
+ exec_aten::TensorShapeDynamism dynamism =
234
+ exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) {
235
+ ET_CHECK_MSG (type == deduced_type, " Type does not match the deduced type." );
236
+ std::vector<exec_aten::SizesType> sizes{exec_aten::SizesType (list.size ())};
237
+ return make_tensor_impl_ptr (
238
+ std::move (sizes), std::move (list), {0 }, {1 }, type, dynamism);
239
+ }
240
+
241
+ /* *
242
+ * Creates a TensorImplPtr to manage a Tensor with a single scalar value.
243
+ *
244
+ * @tparam T The C++ type of the scalar value.
245
+ * @param value The scalar value used for the Tensor.
246
+ * @return A TensorImplPtr managing the newly created TensorImpl.
247
+ */
248
+ template <typename T>
249
+ inline TensorImplPtr make_tensor_impl_ptr (T value) {
250
+ return make_tensor_impl_ptr ({}, std::vector<T>{value});
133
251
}
134
252
135
253
/* *
136
254
* Creates a TensorImplPtr that manages a newly created TensorImpl with the
137
255
* specified properties.
138
256
*
139
257
* This overload accepts a raw memory buffer stored in a std::vector<uint8_t>
140
- * and a scalar type to interpret the data. The vector is managed, and the
141
- * memory's lifetime is tied to the TensorImpl.
258
+ * and a scalar type to interpret the data. The vector is managed, and its
259
+ * lifetime is tied to the TensorImpl.
142
260
*
143
- * @param scalar_type The scalar type of the tensor elements.
144
261
* @param sizes A vector specifying the size of each dimension.
145
- * @param data A vector containing the raw memory for the tensor's data.
262
+ * @param data A vector containing the raw memory buffer for the tensor's data.
146
263
* @param dim_order A vector specifying the order of dimensions.
147
264
* @param strides A vector specifying the strides of each dimension.
265
+ * @param type The scalar type of the tensor elements.
148
266
* @param dynamism Specifies the mutability of the tensor's shape.
149
267
* @return A TensorImplPtr managing the newly created TensorImpl.
150
268
*/
151
269
TensorImplPtr make_tensor_impl_ptr (
152
- exec_aten::ScalarType scalar_type,
153
270
std::vector<exec_aten::SizesType> sizes,
154
271
std::vector<uint8_t > data,
155
- std::vector<exec_aten::DimOrderType> dim_order = {},
156
- std::vector<exec_aten::StridesType> strides = {},
272
+ std::vector<exec_aten::DimOrderType> dim_order,
273
+ std::vector<exec_aten::StridesType> strides,
274
+ exec_aten::ScalarType type = exec_aten::ScalarType::Float,
157
275
exec_aten::TensorShapeDynamism dynamism =
158
276
exec_aten::TensorShapeDynamism::DYNAMIC_BOUND);
159
277
278
+ /* *
279
+ * Creates a TensorImplPtr that manages a newly created TensorImpl with the
280
+ * specified properties.
281
+ *
282
+ * This overload accepts a raw memory buffer stored in a std::vector<uint8_t>
283
+ * and a scalar type to interpret the data. The vector is managed, and the
284
+ * memory's lifetime is tied to the TensorImpl.
285
+ *
286
+ * @param sizes A vector specifying the size of each dimension.
287
+ * @param data A vector containing the raw memory for the tensor's data.
288
+ * @param type The scalar type of the tensor elements.
289
+ * @param dynamism Specifies the mutability of the tensor's shape.
290
+ * @return A TensorImplPtr managing the newly created TensorImpl.
291
+ */
292
+ inline TensorImplPtr make_tensor_impl_ptr (
293
+ std::vector<exec_aten::SizesType> sizes,
294
+ std::vector<uint8_t > data,
295
+ exec_aten::ScalarType type = exec_aten::ScalarType::Float,
296
+ exec_aten::TensorShapeDynamism dynamism =
297
+ exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) {
298
+ return make_tensor_impl_ptr (
299
+ std::move (sizes), std::move (data), {}, {}, type, dynamism);
300
+ }
301
+
160
302
} // namespace extension
161
303
} // namespace executorch
0 commit comments