Skip to content

Commit 145a3cd

Browse files
shoumikhinlucylq
authored andcommitted
Move type arg to the end to match Aten constructors. (#5379)
Summary: Pull Request resolved: #5379 . Reviewed By: kirklandsign Differential Revision: D62701089 fbshipit-source-id: 3f05961a43db9e6e372ee039c2d832227951fbf6
1 parent 26128d1 commit 145a3cd

File tree

7 files changed

+388
-178
lines changed

7 files changed

+388
-178
lines changed

extension/tensor/tensor_impl_ptr.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,11 +54,11 @@ struct TensorImplPtrDeleter final {
5454
} // namespace
5555

5656
TensorImplPtr make_tensor_impl_ptr(
57-
exec_aten::ScalarType type,
5857
std::vector<exec_aten::SizesType> sizes,
5958
void* data,
6059
std::vector<exec_aten::DimOrderType> dim_order,
6160
std::vector<exec_aten::StridesType> strides,
61+
exec_aten::ScalarType type,
6262
exec_aten::TensorShapeDynamism dynamism,
6363
std::function<void(void*)> deleter) {
6464
const auto dim = sizes.size();
@@ -129,24 +129,24 @@ TensorImplPtr make_tensor_impl_ptr(
129129
}
130130

131131
TensorImplPtr make_tensor_impl_ptr(
132-
exec_aten::ScalarType scalar_type,
133132
std::vector<exec_aten::SizesType> sizes,
134133
std::vector<uint8_t> data,
135134
std::vector<exec_aten::DimOrderType> dim_order,
136135
std::vector<exec_aten::StridesType> strides,
136+
exec_aten::ScalarType type,
137137
exec_aten::TensorShapeDynamism dynamism) {
138138
ET_CHECK_MSG(
139139
data.size() >= exec_aten::compute_numel(sizes.data(), sizes.size()) *
140-
exec_aten::elementSize(scalar_type),
140+
exec_aten::elementSize(type),
141141
"Data size is smaller than required by sizes and scalar type.");
142142
auto raw_data_ptr = data.data();
143143
auto data_ptr = std::make_shared<std::vector<uint8_t>>(std::move(data));
144144
return make_tensor_impl_ptr(
145-
scalar_type,
146145
std::move(sizes),
147146
raw_data_ptr,
148147
std::move(dim_order),
149148
std::move(strides),
149+
type,
150150
dynamism,
151151
[data_ptr = std::move(data_ptr)](void*) {});
152152
}

extension/tensor/tensor_impl_ptr.h

Lines changed: 173 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,14 @@ namespace extension {
2121

2222
#ifndef USE_ATEN_LIB
2323
/**
24-
* A smart pointer type for managing the lifecycle of a TensorImpl.
24+
* A smart pointer for managing the lifecycle of a TensorImpl.
2525
*
26-
* TensorImplPtr uses a shared pointer because multiple Tensor objects might
27-
* share the same underlying data and metadata. This shared ownership model
28-
* ensures that the TensorImpl is only destroyed when all references to it are
29-
* gone, providing a safe and efficient way to manage shared tensor
30-
* implementations. This abstraction is designed to be a safer and more
31-
* convenient alternative to the original TensorImpl, which does not
32-
* manage metadata by design.
26+
* TensorImplPtr uses a shared pointer since multiple Tensor objects may
27+
* share the same underlying data and metadata. This shared ownership ensures
28+
* that the TensorImpl is destroyed only when all references to it are gone,
29+
* providing a safe and efficient way to manage shared tensor implementations.
30+
* It serves as a safer, more convenient alternative to the original TensorImpl,
31+
* which does not manage its metadata by design.
3332
*/
3433
using TensorImplPtr = std::shared_ptr<exec_aten::TensorImpl>;
3534
#else
@@ -48,23 +47,23 @@ using TensorImplPtr =
4847
* Creates a TensorImplPtr that manages a newly created TensorImpl with the
4948
* specified properties.
5049
*
51-
* @param type The scalar type of the tensor elements.
5250
* @param sizes A vector specifying the size of each dimension.
5351
* @param data A pointer to the data buffer.
5452
* @param dim_order A vector specifying the order of dimensions.
5553
* @param strides A vector specifying the strides of each dimension.
54+
* @param type The scalar type of the tensor elements.
5655
* @param dynamism Specifies the mutability of the tensor's shape.
5756
* @param deleter A custom deleter function for managing the lifetime of the
58-
* data buffer. If provided, this deleter will be called when the managed
59-
* TensorImpl object is destroyed.
57+
* data buffer. If provided, this deleter is called when the managed TensorImpl
58+
* is destroyed.
6059
* @return A TensorImplPtr managing the newly created TensorImpl.
6160
*/
6261
TensorImplPtr make_tensor_impl_ptr(
63-
exec_aten::ScalarType type,
6462
std::vector<exec_aten::SizesType> sizes,
6563
void* data,
66-
std::vector<exec_aten::DimOrderType> dim_order = {},
67-
std::vector<exec_aten::StridesType> strides = {},
64+
std::vector<exec_aten::DimOrderType> dim_order,
65+
std::vector<exec_aten::StridesType> strides,
66+
exec_aten::ScalarType type = exec_aten::ScalarType::Float,
6867
exec_aten::TensorShapeDynamism dynamism =
6968
exec_aten::TensorShapeDynamism::DYNAMIC_BOUND,
7069
std::function<void(void*)> deleter = nullptr);
@@ -73,37 +72,64 @@ TensorImplPtr make_tensor_impl_ptr(
7372
* Creates a TensorImplPtr that manages a newly created TensorImpl with the
7473
* specified properties.
7574
*
76-
* This template overload is specialized for cases where the tensor data is
77-
* provided as a vector. The scalar type is automatically deduced from the
78-
* vector's data type. The deleter ensures that the data vector is properly
79-
* managed and its lifetime is tied to the TensorImpl.
75+
* @param sizes A vector specifying the size of each dimension.
76+
* @param data A pointer to the data buffer.
77+
* @param type The scalar type of the tensor elements.
78+
* @param dynamism Specifies the mutability of the tensor's shape.
79+
* @param deleter A custom deleter function for managing the lifetime of the
80+
* data buffer. If provided, this deleter is called when the managed TensorImpl
81+
* is destroyed.
82+
* @return A TensorImplPtr managing the newly created TensorImpl.
83+
*/
84+
inline TensorImplPtr make_tensor_impl_ptr(
85+
std::vector<exec_aten::SizesType> sizes,
86+
void* data,
87+
exec_aten::ScalarType type = exec_aten::ScalarType::Float,
88+
exec_aten::TensorShapeDynamism dynamism =
89+
exec_aten::TensorShapeDynamism::DYNAMIC_BOUND,
90+
std::function<void(void*)> deleter = nullptr) {
91+
return make_tensor_impl_ptr(
92+
std::move(sizes), data, {}, {}, type, dynamism, std::move(deleter));
93+
}
94+
95+
/**
96+
* Creates a TensorImplPtr that manages a newly created TensorImpl with the
97+
* specified properties.
98+
*
99+
* This template overload is specialized for cases where tensor data is provided
100+
* as a vector. The scalar type is automatically deduced from the vector's data
101+
* type. The deleter ensures that the data vector is properly managed, with its
102+
* lifetime tied to the TensorImpl.
80103
*
81104
* @tparam T The C++ type of the tensor elements, deduced from the vector.
82105
* @param sizes A vector specifying the size of each dimension.
83106
* @param data A vector containing the tensor's data.
84107
* @param dim_order A vector specifying the order of dimensions.
85108
* @param strides A vector specifying the strides of each dimension.
109+
* @param type The scalar type of the tensor elements.
86110
* @param dynamism Specifies the mutability of the tensor's shape.
87111
* @return A TensorImplPtr that manages the newly created TensorImpl.
88112
*/
89-
template <typename T = float>
113+
template <
114+
typename T = float,
115+
exec_aten::ScalarType deduced_type = runtime::CppTypeToScalarType<T>::value>
90116
inline TensorImplPtr make_tensor_impl_ptr(
91117
std::vector<exec_aten::SizesType> sizes,
92118
std::vector<T> data,
93119
std::vector<exec_aten::DimOrderType> dim_order = {},
94120
std::vector<exec_aten::StridesType> strides = {},
121+
exec_aten::ScalarType type = deduced_type,
95122
exec_aten::TensorShapeDynamism dynamism =
96123
exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) {
97-
constexpr exec_aten::ScalarType scalar_type =
98-
runtime::CppTypeToScalarType<T>::value;
124+
ET_CHECK_MSG(type == deduced_type, "Type does not match the deduced type.");
99125
const auto raw_data_ptr = data.data();
100126
auto data_ptr = std::make_shared<std::vector<T>>(std::move(data));
101127
return make_tensor_impl_ptr(
102-
scalar_type,
103128
std::move(sizes),
104129
raw_data_ptr,
105130
std::move(dim_order),
106131
std::move(strides),
132+
type,
107133
dynamism,
108134
[data_ptr = std::move(data_ptr)](void*) {});
109135
}
@@ -119,43 +145,159 @@ inline TensorImplPtr make_tensor_impl_ptr(
119145
*
120146
* @tparam T The C++ type of the tensor elements, deduced from the vector.
121147
* @param data A vector containing the tensor's data.
148+
* @param type The scalar type of the tensor elements.
122149
* @param dynamism Specifies the mutability of the tensor's shape.
123150
* @return A TensorImplPtr that manages the newly created TensorImpl.
124151
*/
125-
template <typename T = float>
152+
template <
153+
typename T = float,
154+
exec_aten::ScalarType deduced_type = runtime::CppTypeToScalarType<T>::value>
126155
inline TensorImplPtr make_tensor_impl_ptr(
127156
std::vector<T> data,
157+
exec_aten::ScalarType type = deduced_type,
128158
exec_aten::TensorShapeDynamism dynamism =
129159
exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) {
160+
ET_CHECK_MSG(type == deduced_type, "Type does not match the deduced type.");
130161
std::vector<exec_aten::SizesType> sizes{exec_aten::SizesType(data.size())};
131162
return make_tensor_impl_ptr(
132-
std::move(sizes), std::move(data), {0}, {1}, dynamism);
163+
std::move(sizes), std::move(data), {0}, {1}, type, dynamism);
164+
}
165+
166+
/**
167+
* Creates a TensorImplPtr that manages a newly created TensorImpl with the
168+
* specified properties.
169+
*
170+
* This template overload is specialized for cases where tensor data is provided
171+
* as an initializer list. The scalar type is automatically deduced from the
172+
* initializer list's data type. The deleter ensures that the data is properly
173+
* managed, with its lifetime tied to the TensorImpl.
174+
*
175+
* @tparam T The C++ type of the tensor elements, deduced from the initializer
176+
* list.
177+
* @param sizes A vector specifying the size of each dimension.
178+
* @param list An initializer list containing the tensor's data.
179+
* @param dim_order A vector specifying the order of dimensions.
180+
* @param strides A vector specifying the strides of each dimension.
181+
* @param type The scalar type of the tensor elements.
182+
* @param dynamism Specifies the mutability of the tensor's shape.
183+
* @return A TensorImplPtr that manages the newly created TensorImpl.
184+
*/
185+
template <
186+
typename T = float,
187+
exec_aten::ScalarType deduced_type = runtime::CppTypeToScalarType<T>::value>
188+
inline TensorImplPtr make_tensor_impl_ptr(
189+
std::vector<exec_aten::SizesType> sizes,
190+
std::initializer_list<T> list,
191+
std::vector<exec_aten::DimOrderType> dim_order = {},
192+
std::vector<exec_aten::StridesType> strides = {},
193+
exec_aten::ScalarType type = deduced_type,
194+
exec_aten::TensorShapeDynamism dynamism =
195+
exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) {
196+
ET_CHECK_MSG(type == deduced_type, "Type does not match the deduced type.");
197+
auto data = std::vector<T>(std::move(list));
198+
const auto raw_data_ptr = data.data();
199+
auto data_ptr = std::make_shared<std::vector<T>>(std::move(data));
200+
return make_tensor_impl_ptr(
201+
std::move(sizes),
202+
raw_data_ptr,
203+
std::move(dim_order),
204+
std::move(strides),
205+
type,
206+
dynamism,
207+
[data_ptr = std::move(data_ptr)](void*) {});
208+
}
209+
210+
/**
211+
* Creates a TensorImplPtr that manages a newly created TensorImpl with the
212+
* specified properties.
213+
*
214+
* This template overload is specialized for cases where the tensor data is
215+
* provided as an initializer list. The scalar type is automatically deduced
216+
* from the initializer list's data type. The deleter ensures that the data is
217+
* properly managed and its lifetime is tied to the TensorImpl.
218+
*
219+
* @tparam T The C++ type of the tensor elements, deduced from the initializer
220+
* list.
221+
* @param sizes A vector specifying the size of each dimension.
222+
* @param list An initializer list containing the tensor's data.
223+
* @param type The scalar type of the tensor elements.
224+
* @param dynamism Specifies the mutability of the tensor's shape.
225+
* @return A TensorImplPtr that manages the newly created TensorImpl.
226+
*/
227+
template <
228+
typename T = float,
229+
exec_aten::ScalarType deduced_type = runtime::CppTypeToScalarType<T>::value>
230+
inline TensorImplPtr make_tensor_impl_ptr(
231+
std::initializer_list<T> list,
232+
exec_aten::ScalarType type = deduced_type,
233+
exec_aten::TensorShapeDynamism dynamism =
234+
exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) {
235+
ET_CHECK_MSG(type == deduced_type, "Type does not match the deduced type.");
236+
std::vector<exec_aten::SizesType> sizes{exec_aten::SizesType(list.size())};
237+
return make_tensor_impl_ptr(
238+
std::move(sizes), std::move(list), {0}, {1}, type, dynamism);
239+
}
240+
241+
/**
242+
* Creates a TensorImplPtr to manage a Tensor with a single scalar value.
243+
*
244+
* @tparam T The C++ type of the scalar value.
245+
* @param value The scalar value used for the Tensor.
246+
* @return A TensorImplPtr managing the newly created TensorImpl.
247+
*/
248+
template <typename T>
249+
inline TensorImplPtr make_tensor_impl_ptr(T value) {
250+
return make_tensor_impl_ptr({}, std::vector<T>{value});
133251
}
134252

135253
/**
136254
* Creates a TensorImplPtr that manages a newly created TensorImpl with the
137255
* specified properties.
138256
*
139257
* This overload accepts a raw memory buffer stored in a std::vector<uint8_t>
140-
* and a scalar type to interpret the data. The vector is managed, and the
141-
* memory's lifetime is tied to the TensorImpl.
258+
* and a scalar type to interpret the data. The vector is managed, and its
259+
* lifetime is tied to the TensorImpl.
142260
*
143-
* @param scalar_type The scalar type of the tensor elements.
144261
* @param sizes A vector specifying the size of each dimension.
145-
* @param data A vector containing the raw memory for the tensor's data.
262+
* @param data A vector containing the raw memory buffer for the tensor's data.
146263
* @param dim_order A vector specifying the order of dimensions.
147264
* @param strides A vector specifying the strides of each dimension.
265+
* @param type The scalar type of the tensor elements.
148266
* @param dynamism Specifies the mutability of the tensor's shape.
149267
* @return A TensorImplPtr managing the newly created TensorImpl.
150268
*/
151269
TensorImplPtr make_tensor_impl_ptr(
152-
exec_aten::ScalarType scalar_type,
153270
std::vector<exec_aten::SizesType> sizes,
154271
std::vector<uint8_t> data,
155-
std::vector<exec_aten::DimOrderType> dim_order = {},
156-
std::vector<exec_aten::StridesType> strides = {},
272+
std::vector<exec_aten::DimOrderType> dim_order,
273+
std::vector<exec_aten::StridesType> strides,
274+
exec_aten::ScalarType type = exec_aten::ScalarType::Float,
157275
exec_aten::TensorShapeDynamism dynamism =
158276
exec_aten::TensorShapeDynamism::DYNAMIC_BOUND);
159277

278+
/**
279+
* Creates a TensorImplPtr that manages a newly created TensorImpl with the
280+
* specified properties.
281+
*
282+
* This overload accepts a raw memory buffer stored in a std::vector<uint8_t>
283+
* and a scalar type to interpret the data. The vector is managed, and the
284+
* memory's lifetime is tied to the TensorImpl.
285+
*
286+
* @param sizes A vector specifying the size of each dimension.
287+
* @param data A vector containing the raw memory for the tensor's data.
288+
* @param type The scalar type of the tensor elements.
289+
* @param dynamism Specifies the mutability of the tensor's shape.
290+
* @return A TensorImplPtr managing the newly created TensorImpl.
291+
*/
292+
inline TensorImplPtr make_tensor_impl_ptr(
293+
std::vector<exec_aten::SizesType> sizes,
294+
std::vector<uint8_t> data,
295+
exec_aten::ScalarType type = exec_aten::ScalarType::Float,
296+
exec_aten::TensorShapeDynamism dynamism =
297+
exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) {
298+
return make_tensor_impl_ptr(
299+
std::move(sizes), std::move(data), {}, {}, type, dynamism);
300+
}
301+
160302
} // namespace extension
161303
} // namespace executorch

0 commit comments

Comments
 (0)