Skip to content

Commit 45210bb

Browse files
shoumikhinfacebook-github-bot
authored andcommitted
Fix tensor cloning when data is null. (#5535)
Summary: Pull Request resolved: #5535 . Reviewed By: dltn Differential Revision: D63201286 fbshipit-source-id: 1767a1c0cf876f7a3b6b4534a83c912c3de0eabf
1 parent 3ec4161 commit 45210bb

File tree

3 files changed

+46
-23
lines changed

3 files changed

+46
-23
lines changed

extension/tensor/tensor_ptr.cpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,39 @@
1313
namespace executorch {
1414
namespace extension {
1515

16+
TensorPtr clone_tensor_ptr(const exec_aten::Tensor& tensor) {
17+
std::vector<exec_aten::SizesType> sizes(
18+
tensor.sizes().begin(), tensor.sizes().end());
19+
std::vector<exec_aten::DimOrderType> dim_order{
20+
#ifndef USE_ATEN_LIB
21+
tensor.dim_order().begin(), tensor.dim_order().end()
22+
#endif // USE_ATEN_LIB
23+
};
24+
std::vector<exec_aten::StridesType> strides(
25+
tensor.strides().begin(), tensor.strides().end());
26+
auto dynamism = exec_aten::TensorShapeDynamism::DYNAMIC_BOUND;
27+
#ifndef USE_ATEN_LIB
28+
dynamism = tensor.shape_dynamism();
29+
#endif // USE_ATEN_LIB
30+
return tensor.const_data_ptr()
31+
? make_tensor_ptr(
32+
std::move(sizes),
33+
std::vector<uint8_t>(
34+
(uint8_t*)tensor.const_data_ptr(),
35+
(uint8_t*)tensor.const_data_ptr() + tensor.nbytes()),
36+
std::move(dim_order),
37+
std::move(strides),
38+
tensor.scalar_type(),
39+
dynamism)
40+
: make_tensor_ptr(
41+
std::move(sizes),
42+
nullptr,
43+
std::move(dim_order),
44+
std::move(strides),
45+
tensor.scalar_type(),
46+
dynamism);
47+
}
48+
1649
runtime::Error resize_tensor_ptr(
1750
TensorPtr& tensor,
1851
const std::vector<exec_aten::SizesType>& sizes) {

extension/tensor/tensor_ptr.h

Lines changed: 2 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -397,34 +397,13 @@ inline TensorPtr make_tensor_ptr(
397397
/**
398398
* Creates a TensorPtr that manages a new Tensor with the same properties
399399
* as the given Tensor, but with a copy of the data owned by the returned
400-
* TensorPtr.
400+
* TensorPtr, or nullptr if the original data is null.
401401
*
402402
* @param tensor The Tensor to clone.
403403
* @return A new TensorPtr that manages a Tensor with the same properties as the
404404
* original but with copied data.
405405
*/
406-
inline TensorPtr clone_tensor_ptr(const exec_aten::Tensor& tensor) {
407-
return make_tensor_ptr(make_tensor_impl_ptr(
408-
std::vector<exec_aten::SizesType>(
409-
tensor.sizes().begin(), tensor.sizes().end()),
410-
std::vector<uint8_t>(
411-
(uint8_t*)tensor.const_data_ptr(),
412-
(uint8_t*)tensor.const_data_ptr() + tensor.nbytes()),
413-
#ifndef USE_ATEN_LIB
414-
std::vector<exec_aten::DimOrderType>(
415-
tensor.dim_order().begin(), tensor.dim_order().end()),
416-
std::vector<exec_aten::StridesType>(
417-
tensor.strides().begin(), tensor.strides().end()),
418-
tensor.scalar_type(),
419-
tensor.shape_dynamism()
420-
#else // USE_ATEN_LIB
421-
{},
422-
std::vector<exec_aten::StridesType>(
423-
tensor.strides().begin(), tensor.strides().end()),
424-
tensor.scalar_type()
425-
#endif // USE_ATEN_LIB
426-
));
427-
}
406+
TensorPtr clone_tensor_ptr(const exec_aten::Tensor& tensor);
428407

429408
/**
430409
* Creates a new TensorPtr by cloning the given TensorPtr, copying the

extension/tensor/test/tensor_ptr_test.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -478,6 +478,17 @@ TEST_F(TensorPtrTest, CloneTensorPtrFromTensorPtrInt64) {
478478
EXPECT_EQ(cloned_tensor->scalar_type(), exec_aten::ScalarType::Long);
479479
}
480480

481+
TEST_F(TensorPtrTest, CloneTensorPtrFromTensorPtrNull) {
482+
auto tensor = make_tensor_ptr({2, 2}, nullptr);
483+
auto cloned_tensor = clone_tensor_ptr(tensor);
484+
485+
EXPECT_EQ(cloned_tensor->dim(), tensor->dim());
486+
EXPECT_EQ(cloned_tensor->size(0), tensor->size(0));
487+
EXPECT_EQ(cloned_tensor->size(1), tensor->size(1));
488+
EXPECT_EQ(cloned_tensor->const_data_ptr(), tensor->const_data_ptr());
489+
EXPECT_EQ(cloned_tensor->const_data_ptr(), nullptr);
490+
}
491+
481492
TEST_F(TensorPtrTest, TensorDataCastingFromIntToFloat) {
482493
std::vector<int32_t> int_data = {1, 2, 3, 4, 5, 6};
483494
auto tensor = make_tensor_ptr(

0 commit comments

Comments
 (0)