Skip to content

Commit 41fd81a

Browse files
shoumikhinfacebook-github-bot
authored andcommitted
Provide more options to create an owning tensor.
Summary: . Differential Revision: D62339509
1 parent cb944b7 commit 41fd81a

File tree

6 files changed

+162
-21
lines changed

6 files changed

+162
-21
lines changed

extension/tensor/targets.bzl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,10 @@ def define_common_targets():
2727
],
2828
deps = [
2929
"//executorch/runtime/core/exec_aten/util:dim_order_util" + aten_suffix,
30-
"//executorch/runtime/core/exec_aten/util:scalar_type_util" + aten_suffix,
3130
"//executorch/runtime/core/exec_aten/util:tensor_util" + aten_suffix,
3231
],
3332
exported_deps = [
3433
"//executorch/runtime/core/exec_aten:lib" + aten_suffix,
34+
"//executorch/runtime/core/exec_aten/util:scalar_type_util" + aten_suffix,
3535
],
3636
)

extension/tensor/tensor_impl_ptr.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,5 +121,28 @@ TensorImplPtr make_tensor_impl_ptr(
121121
#endif // USE_ATEN_LIB
122122
}
123123

124+
TensorImplPtr make_tensor_impl_ptr(
125+
exec_aten::ScalarType scalar_type,
126+
std::vector<exec_aten::SizesType> sizes,
127+
std::vector<uint8_t> data,
128+
std::vector<exec_aten::DimOrderType> dim_order,
129+
std::vector<exec_aten::StridesType> strides,
130+
exec_aten::TensorShapeDynamism dynamism) {
131+
ET_CHECK_MSG(
132+
data.size() >= exec_aten::compute_numel(sizes.data(), sizes.size()) *
133+
exec_aten::elementSize(scalar_type),
134+
"Data size is smaller than required by sizes and scalar type.");
135+
auto raw_data_ptr = data.data();
136+
auto data_ptr = std::make_shared<std::vector<uint8_t>>(std::move(data));
137+
return make_tensor_impl_ptr(
138+
scalar_type,
139+
std::move(sizes),
140+
raw_data_ptr,
141+
std::move(dim_order),
142+
std::move(strides),
143+
dynamism,
144+
[data_ptr = std::move(data_ptr)](void*) {});
145+
}
146+
124147
} // namespace extension
125148
} // namespace executorch

extension/tensor/tensor_impl_ptr.h

Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -74,32 +74,32 @@ TensorImplPtr make_tensor_impl_ptr(
7474
* specified properties.
7575
*
7676
* This template overload is specialized for cases where the tensor data is
77-
* provided as a vector of a specific scalar type, rather than a raw pointer.
78-
* The deleter ensures that the data vector is properly managed and its
79-
* lifetime is tied to the TensorImpl.
77+
* provided as a vector. The scalar type is automatically deduced from the
78+
* vector's data type. The deleter ensures that the data vector is properly
79+
* managed and its lifetime is tied to the TensorImpl.
8080
*
81-
* @tparam T The scalar type of the tensor elements.
81+
* @tparam T The C++ type of the tensor elements, deduced from the vector.
8282
* @param sizes A vector specifying the size of each dimension.
8383
* @param data A vector containing the tensor's data.
8484
* @param dim_order A vector specifying the order of dimensions.
8585
* @param strides A vector specifying the strides of each dimension.
8686
* @param dynamism Specifies the mutability of the tensor's shape.
87-
* @return A TensorImplPtr managing the newly created TensorImpl.
87+
* @return A TensorImplPtr that manages the newly created TensorImpl.
8888
*/
89-
template <exec_aten::ScalarType T = exec_aten::ScalarType::Float>
89+
template <typename T = float>
9090
TensorImplPtr make_tensor_impl_ptr(
9191
std::vector<exec_aten::SizesType> sizes,
92-
std::vector<typename runtime::ScalarTypeToCppType<T>::type> data,
92+
std::vector<T> data,
9393
std::vector<exec_aten::DimOrderType> dim_order = {},
9494
std::vector<exec_aten::StridesType> strides = {},
9595
exec_aten::TensorShapeDynamism dynamism =
9696
exec_aten::TensorShapeDynamism::STATIC) {
97+
constexpr exec_aten::ScalarType scalar_type =
98+
runtime::CppTypeToScalarType<T>::value;
9799
auto raw_data_ptr = data.data();
98-
auto data_ptr = std::make_shared<
99-
std::vector<typename runtime::ScalarTypeToCppType<T>::type>>(
100-
std::move(data));
100+
auto data_ptr = std::make_shared<std::vector<T>>(std::move(data));
101101
return make_tensor_impl_ptr(
102-
T,
102+
scalar_type,
103103
std::move(sizes),
104104
raw_data_ptr,
105105
std::move(dim_order),
@@ -108,5 +108,30 @@ TensorImplPtr make_tensor_impl_ptr(
108108
[data_ptr = std::move(data_ptr)](void*) {});
109109
}
110110

111+
/**
112+
* Creates a TensorImplPtr that manages a newly created TensorImpl with the
113+
* specified properties.
114+
*
115+
* This overload accepts a raw memory buffer stored in a std::vector<uint8_t>
116+
* and a scalar type to interpret the data. The vector is managed, and the
117+
* memory's lifetime is tied to the TensorImpl.
118+
*
119+
* @param scalar_type The scalar type of the tensor elements.
120+
* @param sizes A vector specifying the size of each dimension.
121+
* @param data A vector containing the raw memory for the tensor's data.
122+
* @param dim_order A vector specifying the order of dimensions.
123+
* @param strides A vector specifying the strides of each dimension.
124+
* @param dynamism Specifies the mutability of the tensor's shape.
125+
* @return A TensorImplPtr managing the newly created TensorImpl.
126+
*/
127+
TensorImplPtr make_tensor_impl_ptr(
128+
exec_aten::ScalarType scalar_type,
129+
std::vector<exec_aten::SizesType> sizes,
130+
std::vector<uint8_t> data,
131+
std::vector<exec_aten::DimOrderType> dim_order = {},
132+
std::vector<exec_aten::StridesType> strides = {},
133+
exec_aten::TensorShapeDynamism dynamism =
134+
exec_aten::TensorShapeDynamism::STATIC);
135+
111136
} // namespace extension
112137
} // namespace executorch

extension/tensor/tensor_ptr.h

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -141,27 +141,59 @@ inline TensorPtr make_tensor_ptr(
141141
* Creates a TensorPtr that manages a Tensor with the specified properties.
142142
*
143143
* This template overload is specialized for cases where the tensor data is
144-
* provided as a vector of a specific scalar type, rather than a raw pointer.
145-
* The deleter ensures that the data vector is properly managed and its
146-
* lifetime is tied to the TensorImpl.
144+
* provided as a vector. The scalar type is automatically deduced from the
145+
* vector's data type. The deleter ensures that the data vector is properly
146+
* managed and its lifetime is tied to the TensorImpl.
147147
*
148-
* @tparam T The scalar type of the tensor elements.
148+
* @tparam T The C++ type of the tensor elements, deduced from the vector.
149149
* @param sizes A vector specifying the size of each dimension.
150150
* @param data A vector containing the tensor's data.
151151
* @param dim_order A vector specifying the order of dimensions.
152152
* @param strides A vector specifying the strides of each dimension.
153153
* @param dynamism Specifies the mutability of the tensor's shape.
154-
* @return A TensorImplPtr managing the newly created TensorImpl.
154+
* @return A TensorPtr that manages the newly created TensorImpl.
155155
*/
156-
template <exec_aten::ScalarType T = exec_aten::ScalarType::Float>
156+
template <typename T = float>
157157
TensorPtr make_tensor_ptr(
158158
std::vector<exec_aten::SizesType> sizes,
159-
std::vector<typename runtime::ScalarTypeToCppType<T>::type> data,
159+
std::vector<T> data,
160160
std::vector<exec_aten::DimOrderType> dim_order = {},
161161
std::vector<exec_aten::StridesType> strides = {},
162162
exec_aten::TensorShapeDynamism dynamism =
163163
exec_aten::TensorShapeDynamism::STATIC) {
164-
return make_tensor_ptr(make_tensor_impl_ptr<T>(
164+
return make_tensor_ptr(make_tensor_impl_ptr(
165+
std::move(sizes),
166+
std::move(data),
167+
std::move(dim_order),
168+
std::move(strides),
169+
dynamism));
170+
}
171+
172+
/**
173+
* Creates a TensorPtr that manages a Tensor with the specified properties.
174+
*
175+
* This overload accepts a raw memory buffer stored in a std::vector<uint8_t>
176+
* and a scalar type to interpret the data. The vector is managed, and the
177+
* memory's lifetime is tied to the TensorImpl.
178+
*
179+
* @param scalar_type The scalar type of the tensor elements.
180+
* @param sizes A vector specifying the size of each dimension.
181+
* @param data A vector containing the raw memory for the tensor's data.
182+
* @param dim_order A vector specifying the order of dimensions.
183+
* @param strides A vector specifying the strides of each dimension.
184+
* @param dynamism Specifies the mutability of the tensor's shape.
185+
* @return A TensorPtr managing the newly created Tensor.
186+
*/
187+
inline TensorPtr make_tensor_ptr(
188+
exec_aten::ScalarType scalar_type,
189+
std::vector<exec_aten::SizesType> sizes,
190+
std::vector<uint8_t> data,
191+
std::vector<exec_aten::DimOrderType> dim_order = {},
192+
std::vector<exec_aten::StridesType> strides = {},
193+
exec_aten::TensorShapeDynamism dynamism =
194+
exec_aten::TensorShapeDynamism::STATIC) {
195+
return make_tensor_ptr(make_tensor_impl_ptr(
196+
scalar_type,
165197
std::move(sizes),
166198
std::move(data),
167199
std::move(dim_order),

extension/tensor/test/targets.bzl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,5 @@ def define_common_targets():
1919
],
2020
deps = [
2121
"//executorch/extension/tensor:tensor" + aten_suffix,
22-
"//executorch/runtime/core/exec_aten/testing_util:tensor_util" + aten_suffix,
2322
],
2423
)

extension/tensor/test/tensor_impl_ptr_test.cpp

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,3 +224,65 @@ TEST_F(TensorImplPtrTest, CustomDeleterWithSharedData) {
224224
EXPECT_TRUE(deleter_called);
225225
EXPECT_EQ(data.use_count(), 1);
226226
}
227+
228+
TEST_F(TensorImplPtrTest, TensorImplDeducedScalarType) {
229+
std::vector<double> data = {1.0, 2.0, 3.0, 4.0};
230+
auto tensor_impl = make_tensor_impl_ptr({2, 2}, std::move(data));
231+
232+
EXPECT_EQ(tensor_impl->dim(), 2);
233+
EXPECT_EQ(tensor_impl->size(0), 2);
234+
EXPECT_EQ(tensor_impl->size(1), 2);
235+
EXPECT_EQ(tensor_impl->strides()[0], 2);
236+
EXPECT_EQ(tensor_impl->strides()[1], 1);
237+
EXPECT_EQ(((double*)tensor_impl->data())[0], 1.0);
238+
EXPECT_EQ(((double*)tensor_impl->data())[3], 4.0);
239+
}
240+
241+
TEST_F(TensorImplPtrTest, TensorImplUint8BufferWithFloatScalarType) {
242+
std::vector<uint8_t> data(
243+
4 * exec_aten::elementSize(exec_aten::ScalarType::Float));
244+
245+
float* float_data = reinterpret_cast<float*>(data.data());
246+
float_data[0] = 1.0f;
247+
float_data[1] = 2.0f;
248+
float_data[2] = 3.0f;
249+
float_data[3] = 4.0f;
250+
251+
auto tensor_impl = make_tensor_impl_ptr(
252+
exec_aten::ScalarType::Float, {2, 2}, std::move(data));
253+
254+
EXPECT_EQ(tensor_impl->dim(), 2);
255+
EXPECT_EQ(tensor_impl->size(0), 2);
256+
EXPECT_EQ(tensor_impl->size(1), 2);
257+
EXPECT_EQ(tensor_impl->strides()[0], 2);
258+
EXPECT_EQ(tensor_impl->strides()[1], 1);
259+
260+
EXPECT_EQ(((float*)tensor_impl->data())[0], 1.0f);
261+
EXPECT_EQ(((float*)tensor_impl->data())[1], 2.0f);
262+
EXPECT_EQ(((float*)tensor_impl->data())[2], 3.0f);
263+
EXPECT_EQ(((float*)tensor_impl->data())[3], 4.0f);
264+
}
265+
266+
TEST_F(TensorImplPtrTest, TensorImplUint8BufferTooSmallExpectDeath) {
267+
std::vector<uint8_t> data(
268+
2 * exec_aten::elementSize(exec_aten::ScalarType::Float));
269+
ET_EXPECT_DEATH(
270+
{
271+
auto tensor_impl = make_tensor_impl_ptr(
272+
exec_aten::ScalarType::Float, {2, 2}, std::move(data));
273+
},
274+
"");
275+
}
276+
277+
TEST_F(TensorImplPtrTest, TensorImplUint8BufferTooLarge) {
278+
std::vector<uint8_t> data(
279+
4 * exec_aten::elementSize(exec_aten::ScalarType::Float));
280+
auto tensor_impl = make_tensor_impl_ptr(
281+
exec_aten::ScalarType::Float, {2, 2}, std::move(data));
282+
283+
EXPECT_EQ(tensor_impl->dim(), 2);
284+
EXPECT_EQ(tensor_impl->size(0), 2);
285+
EXPECT_EQ(tensor_impl->size(1), 2);
286+
EXPECT_EQ(tensor_impl->strides()[0], 2);
287+
EXPECT_EQ(tensor_impl->strides()[1], 1);
288+
}

0 commit comments

Comments
 (0)