Skip to content

Commit e8c906e

Browse files
shoumikhinpytorchbot
authored andcommitted
Move type arg to the end to match Aten constructors. (#5379)
Summary: Pull Request resolved: #5379 . Reviewed By: kirklandsign Differential Revision: D62701089 fbshipit-source-id: 3f05961a43db9e6e372ee039c2d832227951fbf6 (cherry picked from commit c252553)
1 parent eecf74f commit e8c906e

File tree

7 files changed

+402
-172
lines changed

7 files changed

+402
-172
lines changed

extension/tensor/tensor_impl_ptr.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,11 +54,11 @@ struct TensorImplPtrDeleter final {
5454
} // namespace
5555

5656
TensorImplPtr make_tensor_impl_ptr(
57-
exec_aten::ScalarType type,
5857
std::vector<exec_aten::SizesType> sizes,
5958
void* data,
6059
std::vector<exec_aten::DimOrderType> dim_order,
6160
std::vector<exec_aten::StridesType> strides,
61+
exec_aten::ScalarType type,
6262
exec_aten::TensorShapeDynamism dynamism,
6363
std::function<void(void*)> deleter) {
6464
const auto dim = sizes.size();
@@ -122,24 +122,24 @@ TensorImplPtr make_tensor_impl_ptr(
122122
}
123123

124124
TensorImplPtr make_tensor_impl_ptr(
125-
exec_aten::ScalarType scalar_type,
126125
std::vector<exec_aten::SizesType> sizes,
127126
std::vector<uint8_t> data,
128127
std::vector<exec_aten::DimOrderType> dim_order,
129128
std::vector<exec_aten::StridesType> strides,
129+
exec_aten::ScalarType type,
130130
exec_aten::TensorShapeDynamism dynamism) {
131131
ET_CHECK_MSG(
132132
data.size() >= exec_aten::compute_numel(sizes.data(), sizes.size()) *
133-
exec_aten::elementSize(scalar_type),
133+
exec_aten::elementSize(type),
134134
"Data size is smaller than required by sizes and scalar type.");
135135
auto raw_data_ptr = data.data();
136136
auto data_ptr = std::make_shared<std::vector<uint8_t>>(std::move(data));
137137
return make_tensor_impl_ptr(
138-
scalar_type,
139138
std::move(sizes),
140139
raw_data_ptr,
141140
std::move(dim_order),
142141
std::move(strides),
142+
type,
143143
dynamism,
144144
[data_ptr = std::move(data_ptr)](void*) {});
145145
}

extension/tensor/tensor_impl_ptr.h

Lines changed: 171 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,14 @@ namespace extension {
2121

2222
#ifndef USE_ATEN_LIB
2323
/**
24-
* A smart pointer type for managing the lifecycle of a TensorImpl.
24+
* A smart pointer for managing the lifecycle of a TensorImpl.
2525
*
26-
* TensorImplPtr uses a shared pointer because multiple Tensor objects might
27-
* share the same underlying data and metadata. This shared ownership model
28-
* ensures that the TensorImpl is only destroyed when all references to it are
29-
* gone, providing a safe and efficient way to manage shared tensor
30-
* implementations. This abstraction is designed to be a safer and more
31-
* convenient alternative to the original TensorImpl, which does not
32-
* manage metadata by design.
26+
* TensorImplPtr uses a shared pointer since multiple Tensor objects may
27+
* share the same underlying data and metadata. This shared ownership ensures
28+
* that the TensorImpl is destroyed only when all references to it are gone,
29+
* providing a safe and efficient way to manage shared tensor implementations.
30+
* It serves as a safer, more convenient alternative to the original TensorImpl,
31+
* which does not manage its metadata by design.
3332
*/
3433
using TensorImplPtr = std::shared_ptr<exec_aten::TensorImpl>;
3534
#else
@@ -48,23 +47,23 @@ using TensorImplPtr =
4847
* Creates a TensorImplPtr that manages a newly created TensorImpl with the
4948
* specified properties.
5049
*
51-
* @param type The scalar type of the tensor elements.
5250
* @param sizes A vector specifying the size of each dimension.
5351
* @param data A pointer to the data buffer.
5452
* @param dim_order A vector specifying the order of dimensions.
5553
* @param strides A vector specifying the strides of each dimension.
54+
* @param type The scalar type of the tensor elements.
5655
* @param dynamism Specifies the mutability of the tensor's shape.
5756
* @param deleter A custom deleter function for managing the lifetime of the
58-
* data buffer. If provided, this deleter will be called when the managed
59-
* TensorImpl object is destroyed.
57+
* data buffer. If provided, this deleter is called when the managed TensorImpl
58+
* is destroyed.
6059
* @return A TensorImplPtr managing the newly created TensorImpl.
6160
*/
6261
TensorImplPtr make_tensor_impl_ptr(
63-
exec_aten::ScalarType type,
6462
std::vector<exec_aten::SizesType> sizes,
6563
void* data,
66-
std::vector<exec_aten::DimOrderType> dim_order = {},
67-
std::vector<exec_aten::StridesType> strides = {},
64+
std::vector<exec_aten::DimOrderType> dim_order,
65+
std::vector<exec_aten::StridesType> strides,
66+
exec_aten::ScalarType type = exec_aten::ScalarType::Float,
6867
exec_aten::TensorShapeDynamism dynamism =
6968
exec_aten::TensorShapeDynamism::DYNAMIC_BOUND,
7069
std::function<void(void*)> deleter = nullptr);
@@ -73,37 +72,64 @@ TensorImplPtr make_tensor_impl_ptr(
7372
* Creates a TensorImplPtr that manages a newly created TensorImpl with the
7473
* specified properties.
7574
*
76-
* This template overload is specialized for cases where the tensor data is
77-
* provided as a vector. The scalar type is automatically deduced from the
78-
* vector's data type. The deleter ensures that the data vector is properly
79-
* managed and its lifetime is tied to the TensorImpl.
75+
* @param sizes A vector specifying the size of each dimension.
76+
* @param data A pointer to the data buffer.
77+
* @param type The scalar type of the tensor elements.
78+
* @param dynamism Specifies the mutability of the tensor's shape.
79+
* @param deleter A custom deleter function for managing the lifetime of the
80+
* data buffer. If provided, this deleter is called when the managed TensorImpl
81+
* is destroyed.
82+
* @return A TensorImplPtr managing the newly created TensorImpl.
83+
*/
84+
inline TensorImplPtr make_tensor_impl_ptr(
85+
std::vector<exec_aten::SizesType> sizes,
86+
void* data,
87+
exec_aten::ScalarType type = exec_aten::ScalarType::Float,
88+
exec_aten::TensorShapeDynamism dynamism =
89+
exec_aten::TensorShapeDynamism::DYNAMIC_BOUND,
90+
std::function<void(void*)> deleter = nullptr) {
91+
return make_tensor_impl_ptr(
92+
std::move(sizes), data, {}, {}, type, dynamism, std::move(deleter));
93+
}
94+
95+
/**
96+
* Creates a TensorImplPtr that manages a newly created TensorImpl with the
97+
* specified properties.
98+
*
99+
* This template overload is specialized for cases where tensor data is provided
100+
* as a vector. The scalar type is automatically deduced from the vector's data
101+
* type. The deleter ensures that the data vector is properly managed, with its
102+
* lifetime tied to the TensorImpl.
80103
*
81104
* @tparam T The C++ type of the tensor elements, deduced from the vector.
82105
* @param sizes A vector specifying the size of each dimension.
83106
* @param data A vector containing the tensor's data.
84107
* @param dim_order A vector specifying the order of dimensions.
85108
* @param strides A vector specifying the strides of each dimension.
109+
* @param type The scalar type of the tensor elements.
86110
* @param dynamism Specifies the mutability of the tensor's shape.
87111
* @return A TensorImplPtr that manages the newly created TensorImpl.
88112
*/
89-
template <typename T = float>
90-
TensorImplPtr make_tensor_impl_ptr(
113+
template <
114+
typename T = float,
115+
exec_aten::ScalarType deduced_type = runtime::CppTypeToScalarType<T>::value>
116+
inline TensorImplPtr make_tensor_impl_ptr(
91117
std::vector<exec_aten::SizesType> sizes,
92118
std::vector<T> data,
93119
std::vector<exec_aten::DimOrderType> dim_order = {},
94120
std::vector<exec_aten::StridesType> strides = {},
121+
exec_aten::ScalarType type = deduced_type,
95122
exec_aten::TensorShapeDynamism dynamism =
96123
exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) {
97-
constexpr exec_aten::ScalarType scalar_type =
98-
runtime::CppTypeToScalarType<T>::value;
124+
ET_CHECK_MSG(type == deduced_type, "Type does not match the deduced type.");
99125
const auto raw_data_ptr = data.data();
100126
auto data_ptr = std::make_shared<std::vector<T>>(std::move(data));
101127
return make_tensor_impl_ptr(
102-
scalar_type,
103128
std::move(sizes),
104129
raw_data_ptr,
105130
std::move(dim_order),
106131
std::move(strides),
132+
type,
107133
dynamism,
108134
[data_ptr = std::move(data_ptr)](void*) {});
109135
}
@@ -119,53 +145,161 @@ TensorImplPtr make_tensor_impl_ptr(
119145
*
120146
* @tparam T The C++ type of the tensor elements, deduced from the vector.
121147
* @param data A vector containing the tensor's data.
148+
* @param type The scalar type of the tensor elements.
122149
* @param dynamism Specifies the mutability of the tensor's shape.
123150
* @return A TensorImplPtr that manages the newly created TensorImpl.
124151
*/
125-
template <typename T = float>
126-
TensorImplPtr make_tensor_impl_ptr(
152+
template <
153+
typename T = float,
154+
exec_aten::ScalarType deduced_type = runtime::CppTypeToScalarType<T>::value>
155+
inline TensorImplPtr make_tensor_impl_ptr(
127156
std::vector<T> data,
157+
exec_aten::ScalarType type = deduced_type,
128158
exec_aten::TensorShapeDynamism dynamism =
129159
exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) {
130-
constexpr exec_aten::ScalarType scalar_type =
131-
runtime::CppTypeToScalarType<T>::value;
160+
ET_CHECK_MSG(type == deduced_type, "Type does not match the deduced type.");
132161
std::vector<exec_aten::SizesType> sizes{exec_aten::SizesType(data.size())};
133162
const auto raw_data_ptr = data.data();
134163
auto data_ptr = std::make_shared<std::vector<T>>(std::move(data));
135164
return make_tensor_impl_ptr(
136-
scalar_type,
165+
std::move(sizes), std::move(data), {0}, {1}, type, dynamism);
166+
}
167+
168+
/**
169+
* Creates a TensorImplPtr that manages a newly created TensorImpl with the
170+
* specified properties.
171+
*
172+
* This template overload is specialized for cases where tensor data is provided
173+
* as an initializer list. The scalar type is automatically deduced from the
174+
* initializer list's data type. The deleter ensures that the data is properly
175+
* managed, with its lifetime tied to the TensorImpl.
176+
*
177+
* @tparam T The C++ type of the tensor elements, deduced from the initializer
178+
* list.
179+
* @param sizes A vector specifying the size of each dimension.
180+
* @param list An initializer list containing the tensor's data.
181+
* @param dim_order A vector specifying the order of dimensions.
182+
* @param strides A vector specifying the strides of each dimension.
183+
* @param type The scalar type of the tensor elements.
184+
* @param dynamism Specifies the mutability of the tensor's shape.
185+
* @return A TensorImplPtr that manages the newly created TensorImpl.
186+
*/
187+
template <
188+
typename T = float,
189+
exec_aten::ScalarType deduced_type = runtime::CppTypeToScalarType<T>::value>
190+
inline TensorImplPtr make_tensor_impl_ptr(
191+
std::vector<exec_aten::SizesType> sizes,
192+
std::initializer_list<T> list,
193+
std::vector<exec_aten::DimOrderType> dim_order = {},
194+
std::vector<exec_aten::StridesType> strides = {},
195+
exec_aten::ScalarType type = deduced_type,
196+
exec_aten::TensorShapeDynamism dynamism =
197+
exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) {
198+
ET_CHECK_MSG(type == deduced_type, "Type does not match the deduced type.");
199+
auto data = std::vector<T>(std::move(list));
200+
const auto raw_data_ptr = data.data();
201+
auto data_ptr = std::make_shared<std::vector<T>>(std::move(data));
202+
return make_tensor_impl_ptr(
137203
std::move(sizes),
138204
raw_data_ptr,
139-
{0},
140-
{1},
205+
std::move(dim_order),
206+
std::move(strides),
207+
type,
141208
dynamism,
142209
[data_ptr = std::move(data_ptr)](void*) {});
143210
}
144211

212+
/**
213+
* Creates a TensorImplPtr that manages a newly created TensorImpl with the
214+
* specified properties.
215+
*
216+
* This template overload is specialized for cases where the tensor data is
217+
* provided as an initializer list. The scalar type is automatically deduced
218+
* from the initializer list's data type. The deleter ensures that the data is
219+
* properly managed and its lifetime is tied to the TensorImpl.
220+
*
221+
* @tparam T The C++ type of the tensor elements, deduced from the initializer
222+
* list.
223+
* @param sizes A vector specifying the size of each dimension.
224+
* @param list An initializer list containing the tensor's data.
225+
* @param type The scalar type of the tensor elements.
226+
* @param dynamism Specifies the mutability of the tensor's shape.
227+
* @return A TensorImplPtr that manages the newly created TensorImpl.
228+
*/
229+
template <
230+
typename T = float,
231+
exec_aten::ScalarType deduced_type = runtime::CppTypeToScalarType<T>::value>
232+
inline TensorImplPtr make_tensor_impl_ptr(
233+
std::initializer_list<T> list,
234+
exec_aten::ScalarType type = deduced_type,
235+
exec_aten::TensorShapeDynamism dynamism =
236+
exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) {
237+
ET_CHECK_MSG(type == deduced_type, "Type does not match the deduced type.");
238+
std::vector<exec_aten::SizesType> sizes{exec_aten::SizesType(list.size())};
239+
return make_tensor_impl_ptr(
240+
std::move(sizes), std::move(list), {0}, {1}, type, dynamism);
241+
}
242+
243+
/**
244+
* Creates a TensorImplPtr to manage a Tensor with a single scalar value.
245+
*
246+
* @tparam T The C++ type of the scalar value.
247+
* @param value The scalar value used for the Tensor.
248+
* @return A TensorImplPtr managing the newly created TensorImpl.
249+
*/
250+
template <typename T>
251+
inline TensorImplPtr make_tensor_impl_ptr(T value) {
252+
return make_tensor_impl_ptr({}, std::vector<T>{value});
253+
}
254+
145255
/**
146256
* Creates a TensorImplPtr that manages a newly created TensorImpl with the
147257
* specified properties.
148258
*
149259
* This overload accepts a raw memory buffer stored in a std::vector<uint8_t>
150-
* and a scalar type to interpret the data. The vector is managed, and the
151-
* memory's lifetime is tied to the TensorImpl.
260+
* and a scalar type to interpret the data. The vector is managed, and its
261+
* lifetime is tied to the TensorImpl.
152262
*
153-
* @param scalar_type The scalar type of the tensor elements.
154263
* @param sizes A vector specifying the size of each dimension.
155-
* @param data A vector containing the raw memory for the tensor's data.
264+
* @param data A vector containing the raw memory buffer for the tensor's data.
156265
* @param dim_order A vector specifying the order of dimensions.
157266
* @param strides A vector specifying the strides of each dimension.
267+
* @param type The scalar type of the tensor elements.
158268
* @param dynamism Specifies the mutability of the tensor's shape.
159269
* @return A TensorImplPtr managing the newly created TensorImpl.
160270
*/
161271
TensorImplPtr make_tensor_impl_ptr(
162-
exec_aten::ScalarType scalar_type,
163272
std::vector<exec_aten::SizesType> sizes,
164273
std::vector<uint8_t> data,
165-
std::vector<exec_aten::DimOrderType> dim_order = {},
166-
std::vector<exec_aten::StridesType> strides = {},
274+
std::vector<exec_aten::DimOrderType> dim_order,
275+
std::vector<exec_aten::StridesType> strides,
276+
exec_aten::ScalarType type = exec_aten::ScalarType::Float,
167277
exec_aten::TensorShapeDynamism dynamism =
168278
exec_aten::TensorShapeDynamism::DYNAMIC_BOUND);
169279

280+
/**
281+
* Creates a TensorImplPtr that manages a newly created TensorImpl with the
282+
* specified properties.
283+
*
284+
* This overload accepts a raw memory buffer stored in a std::vector<uint8_t>
285+
* and a scalar type to interpret the data. The vector is managed, and the
286+
* memory's lifetime is tied to the TensorImpl.
287+
*
288+
* @param sizes A vector specifying the size of each dimension.
289+
* @param data A vector containing the raw memory for the tensor's data.
290+
* @param type The scalar type of the tensor elements.
291+
* @param dynamism Specifies the mutability of the tensor's shape.
292+
* @return A TensorImplPtr managing the newly created TensorImpl.
293+
*/
294+
inline TensorImplPtr make_tensor_impl_ptr(
295+
std::vector<exec_aten::SizesType> sizes,
296+
std::vector<uint8_t> data,
297+
exec_aten::ScalarType type = exec_aten::ScalarType::Float,
298+
exec_aten::TensorShapeDynamism dynamism =
299+
exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) {
300+
return make_tensor_impl_ptr(
301+
std::move(sizes), std::move(data), {}, {}, type, dynamism);
302+
}
303+
170304
} // namespace extension
171305
} // namespace executorch

0 commit comments

Comments
 (0)