Skip to content

Commit 9b05045

Browse files
shoumikhinfacebook-github-bot
authored andcommitted
Add an overload to skip dtype and sizes.
Summary: . Differential Revision: D62366751
1 parent 99fbca3 commit 9b05045

File tree

4 files changed

+194
-3
lines changed

4 files changed

+194
-3
lines changed

extension/tensor/tensor_impl_ptr.h

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ TensorImplPtr make_tensor_impl_ptr(
9696
exec_aten::TensorShapeDynamism::STATIC) {
9797
constexpr exec_aten::ScalarType scalar_type =
9898
runtime::CppTypeToScalarType<T>::value;
99-
auto raw_data_ptr = data.data();
99+
const auto raw_data_ptr = data.data();
100100
auto data_ptr = std::make_shared<std::vector<T>>(std::move(data));
101101
return make_tensor_impl_ptr(
102102
scalar_type,
@@ -108,6 +108,40 @@ TensorImplPtr make_tensor_impl_ptr(
108108
[data_ptr = std::move(data_ptr)](void*) {});
109109
}
110110

111+
/**
112+
* Creates a TensorImplPtr that manages a newly created TensorImpl with the
113+
* specified properties.
114+
*
115+
* This template overload is specialized for cases where the tensor data is
116+
* provided as a vector. The scalar type is automatically deduced from the
117+
* vector's data type. The deleter ensures that the data vector is properly
118+
* managed and its lifetime is tied to the TensorImpl.
119+
*
120+
* @tparam T The C++ type of the tensor elements, deduced from the vector.
121+
* @param data A vector containing the tensor's data.
122+
* @param dynamism Specifies the mutability of the tensor's shape.
123+
* @return A TensorImplPtr that manages the newly created TensorImpl.
124+
*/
125+
template <typename T = float>
126+
TensorImplPtr make_tensor_impl_ptr(
127+
std::vector<T> data,
128+
exec_aten::TensorShapeDynamism dynamism =
129+
exec_aten::TensorShapeDynamism::STATIC) {
130+
constexpr exec_aten::ScalarType scalar_type =
131+
runtime::CppTypeToScalarType<T>::value;
132+
std::vector<exec_aten::SizesType> sizes{exec_aten::SizesType(data.size())};
133+
const auto raw_data_ptr = data.data();
134+
auto data_ptr = std::make_shared<std::vector<T>>(std::move(data));
135+
return make_tensor_impl_ptr(
136+
scalar_type,
137+
std::move(sizes),
138+
raw_data_ptr,
139+
{0},
140+
{1},
141+
dynamism,
142+
[data_ptr = std::move(data_ptr)](void*) {});
143+
}
144+
111145
/**
112146
* Creates a TensorImplPtr that manages a newly created TensorImpl with the
113147
* specified properties.

extension/tensor/tensor_ptr.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,27 @@ TensorPtr make_tensor_ptr(
169169
dynamism));
170170
}
171171

172+
/**
173+
* Creates a TensorPtr that manages a Tensor with the specified properties.
174+
*
175+
* This template overload is specialized for cases where the tensor data is
176+
* provided as a vector. The scalar type is automatically deduced from the
177+
* vector's data type. The deleter ensures that the data vector is properly
178+
* managed and its lifetime is tied to the TensorImpl.
179+
*
180+
* @tparam T The C++ type of the tensor elements, deduced from the vector.
181+
* @param data A vector containing the tensor's data.
182+
* @param dynamism Specifies the mutability of the tensor's shape.
183+
* @return A TensorPtr that manages the newly created TensorImpl.
184+
*/
185+
template <typename T = float>
186+
TensorPtr make_tensor_ptr(
187+
std::vector<T> data,
188+
exec_aten::TensorShapeDynamism dynamism =
189+
exec_aten::TensorShapeDynamism::STATIC) {
190+
return make_tensor_ptr(make_tensor_impl_ptr(std::move(data), dynamism));
191+
}
192+
172193
/**
173194
* Creates a TensorPtr that manages a Tensor with the specified properties.
174195
*

extension/tensor/test/tensor_impl_ptr_test.cpp

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ TEST_F(TensorImplPtrTest, TensorImplOwningData) {
172172
}
173173

174174
TEST_F(TensorImplPtrTest, TensorImplOwningEmptyData) {
175-
auto tensor_impl = make_tensor_impl_ptr({0, 5}, {});
175+
auto tensor_impl = make_tensor_impl_ptr({0, 5}, std::vector<float>());
176176

177177
EXPECT_EQ(tensor_impl->dim(), 2);
178178
EXPECT_EQ(tensor_impl->size(0), 0);
@@ -182,6 +182,74 @@ TEST_F(TensorImplPtrTest, TensorImplOwningEmptyData) {
182182
EXPECT_EQ(tensor_impl->data(), nullptr);
183183
}
184184

185+
TEST_F(TensorImplPtrTest, TensorImplDataOnlyDoubleType) {
186+
std::vector<double> data = {1.0, 2.0, 3.0, 4.0};
187+
auto tensor_impl = make_tensor_impl_ptr(std::move(data));
188+
189+
EXPECT_EQ(tensor_impl->dim(), 1);
190+
EXPECT_EQ(tensor_impl->size(0), 4);
191+
EXPECT_EQ(tensor_impl->strides()[0], 1);
192+
EXPECT_EQ(((double*)tensor_impl->data())[0], 1.0);
193+
EXPECT_EQ(((double*)tensor_impl->data())[3], 4.0);
194+
}
195+
196+
TEST_F(TensorImplPtrTest, TensorImplDataOnlyInt32Type) {
197+
std::vector<int32_t> data = {10, 20, 30, 40};
198+
auto tensor_impl = make_tensor_impl_ptr(std::move(data));
199+
200+
EXPECT_EQ(tensor_impl->dim(), 1);
201+
EXPECT_EQ(tensor_impl->size(0), 4);
202+
EXPECT_EQ(tensor_impl->strides()[0], 1);
203+
EXPECT_EQ(((int32_t*)tensor_impl->data())[0], 10);
204+
EXPECT_EQ(((int32_t*)tensor_impl->data())[3], 40);
205+
}
206+
207+
TEST_F(TensorImplPtrTest, TensorImplDataOnlyInt64Type) {
208+
std::vector<int64_t> data = {100, 200, 300, 400};
209+
auto tensor_impl = make_tensor_impl_ptr(std::move(data));
210+
211+
EXPECT_EQ(tensor_impl->dim(), 1);
212+
EXPECT_EQ(tensor_impl->size(0), 4);
213+
EXPECT_EQ(tensor_impl->strides()[0], 1);
214+
EXPECT_EQ(((int64_t*)tensor_impl->data())[0], 100);
215+
EXPECT_EQ(((int64_t*)tensor_impl->data())[3], 400);
216+
}
217+
218+
TEST_F(TensorImplPtrTest, TensorImplDataOnlyUint8Type) {
219+
std::vector<uint8_t> data = {10, 20, 30, 40};
220+
auto tensor_impl = make_tensor_impl_ptr(std::move(data));
221+
222+
EXPECT_EQ(tensor_impl->dim(), 1);
223+
EXPECT_EQ(tensor_impl->size(0), 4);
224+
EXPECT_EQ(tensor_impl->strides()[0], 1);
225+
EXPECT_EQ(((uint8_t*)tensor_impl->data())[0], 10);
226+
EXPECT_EQ(((uint8_t*)tensor_impl->data())[3], 40);
227+
}
228+
229+
TEST_F(TensorImplPtrTest, TensorImplAmbiguityWithMixedVectors) {
230+
std::vector<exec_aten::SizesType> sizes = {2, 2};
231+
std::vector<float> data = {1.0f, 2.0f, 3.0f, 4.0f};
232+
auto tensor_impl = make_tensor_impl_ptr(std::move(sizes), std::move(data));
233+
234+
EXPECT_EQ(tensor_impl->dim(), 2);
235+
EXPECT_EQ(tensor_impl->size(0), 2);
236+
EXPECT_EQ(tensor_impl->size(1), 2);
237+
EXPECT_EQ(tensor_impl->strides()[0], 2);
238+
EXPECT_EQ(tensor_impl->strides()[1], 1);
239+
EXPECT_EQ(((float*)tensor_impl->data())[0], 1.0f);
240+
EXPECT_EQ(((float*)tensor_impl->data())[3], 4.0f);
241+
242+
auto tensor_impl2 = make_tensor_impl_ptr({2, 2}, {1.0f, 2.0f, 3.0f, 4.0f});
243+
244+
EXPECT_EQ(tensor_impl2->dim(), 2);
245+
EXPECT_EQ(tensor_impl2->size(0), 2);
246+
EXPECT_EQ(tensor_impl2->size(1), 2);
247+
EXPECT_EQ(tensor_impl2->strides()[0], 2);
248+
EXPECT_EQ(tensor_impl2->strides()[1], 1);
249+
EXPECT_EQ(((float*)tensor_impl2->data())[0], 1.0f);
250+
EXPECT_EQ(((float*)tensor_impl2->data())[3], 4.0f);
251+
}
252+
185253
TEST_F(TensorImplPtrTest, SharedDataManagement) {
186254
auto data = std::make_shared<std::vector<float>>(100, 1.0f);
187255
auto tensor_impl1 = make_tensor_impl_ptr(

extension/tensor/test/tensor_ptr_test.cpp

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ TEST_F(TensorPtrTest, TensorOwningData) {
167167
}
168168

169169
TEST_F(TensorPtrTest, TensorOwningEmptyData) {
170-
auto tensor = make_tensor_ptr({0, 5}, {});
170+
auto tensor = make_tensor_ptr({0, 5}, std::vector<float>());
171171

172172
EXPECT_EQ(tensor->dim(), 2);
173173
EXPECT_EQ(tensor->size(0), 0);
@@ -177,6 +177,74 @@ TEST_F(TensorPtrTest, TensorOwningEmptyData) {
177177
EXPECT_EQ(tensor->data_ptr<float>(), nullptr);
178178
}
179179

180+
TEST_F(TensorPtrTest, TensorImplDataOnlyDoubleType) {
181+
std::vector<double> data = {1.0, 2.0, 3.0, 4.0};
182+
auto tensor = make_tensor_ptr(std::move(data));
183+
184+
EXPECT_EQ(tensor->dim(), 1);
185+
EXPECT_EQ(tensor->size(0), 4);
186+
EXPECT_EQ(tensor->strides()[0], 1);
187+
EXPECT_EQ(tensor->const_data_ptr<double>()[0], 1.0);
188+
EXPECT_EQ(tensor->const_data_ptr<double>()[3], 4.0);
189+
}
190+
191+
TEST_F(TensorPtrTest, TensorImplDataOnlyInt32Type) {
192+
std::vector<int32_t> data = {10, 20, 30, 40};
193+
auto tensor = make_tensor_ptr(std::move(data));
194+
195+
EXPECT_EQ(tensor->dim(), 1);
196+
EXPECT_EQ(tensor->size(0), 4);
197+
EXPECT_EQ(tensor->strides()[0], 1);
198+
EXPECT_EQ(tensor->const_data_ptr<int32_t>()[0], 10);
199+
EXPECT_EQ(tensor->const_data_ptr<int32_t>()[3], 40);
200+
}
201+
202+
TEST_F(TensorPtrTest, TensorImplDataOnlyInt64Type) {
203+
std::vector<int64_t> data = {100, 200, 300, 400};
204+
auto tensor = make_tensor_ptr(std::move(data));
205+
206+
EXPECT_EQ(tensor->dim(), 1);
207+
EXPECT_EQ(tensor->size(0), 4);
208+
EXPECT_EQ(tensor->strides()[0], 1);
209+
EXPECT_EQ(tensor->const_data_ptr<int64_t>()[0], 100);
210+
EXPECT_EQ(tensor->const_data_ptr<int64_t>()[3], 400);
211+
}
212+
213+
TEST_F(TensorPtrTest, TensorImplDataOnlyUint8Type) {
214+
std::vector<uint8_t> data = {10, 20, 30, 40};
215+
auto tensor = make_tensor_ptr(std::move(data));
216+
217+
EXPECT_EQ(tensor->dim(), 1);
218+
EXPECT_EQ(tensor->size(0), 4);
219+
EXPECT_EQ(tensor->strides()[0], 1);
220+
EXPECT_EQ(tensor->const_data_ptr<uint8_t>()[0], 10);
221+
EXPECT_EQ(tensor->const_data_ptr<uint8_t>()[3], 40);
222+
}
223+
224+
TEST_F(TensorPtrTest, TensorImplAmbiguityWithMixedVectors) {
225+
std::vector<exec_aten::SizesType> sizes = {2, 2};
226+
std::vector<float> data = {1.0f, 2.0f, 3.0f, 4.0f};
227+
auto tensor = make_tensor_ptr(std::move(sizes), std::move(data));
228+
229+
EXPECT_EQ(tensor->dim(), 2);
230+
EXPECT_EQ(tensor->size(0), 2);
231+
EXPECT_EQ(tensor->size(1), 2);
232+
EXPECT_EQ(tensor->strides()[0], 2);
233+
EXPECT_EQ(tensor->strides()[1], 1);
234+
EXPECT_EQ(tensor->const_data_ptr<float>()[0], 1.0f);
235+
EXPECT_EQ(tensor->const_data_ptr<float>()[3], 4.0f);
236+
237+
auto tensor2 = make_tensor_ptr({2, 2}, {1.0f, 2.0f, 3.0f, 4.0f});
238+
239+
EXPECT_EQ(tensor2->dim(), 2);
240+
EXPECT_EQ(tensor2->size(0), 2);
241+
EXPECT_EQ(tensor2->size(1), 2);
242+
EXPECT_EQ(tensor2->strides()[0], 2);
243+
EXPECT_EQ(tensor2->strides()[1], 1);
244+
EXPECT_EQ(tensor2->const_data_ptr<float>()[0], 1.0f);
245+
EXPECT_EQ(tensor2->const_data_ptr<float>()[3], 4.0f);
246+
}
247+
180248
TEST_F(TensorPtrTest, TensorSharingImplModifiesSharedDataVector) {
181249
std::vector<float> data = {1, 2, 3, 4, 5, 6};
182250

0 commit comments

Comments
 (0)