Skip to content

Commit cb71193

Browse files
authored
Make TensorImplPtr custom deleter copyable.
Differential Revision: D62338334 Pull Request resolved: #5161
1 parent 258cf71 commit cb71193

File tree

4 files changed

+99
-15
lines changed

4 files changed

+99
-15
lines changed

extension/tensor/tensor_impl_ptr.cpp

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -30,20 +30,23 @@ static void noop_deleter(void*) {}
3030
* TensorImpl is destroyed.
3131
*/
3232
struct TensorImplPtrDeleter final {
33-
std::unique_ptr<void, std::function<void(void*)>> data;
34-
std::vector<exec_aten::SizesType> sizes;
35-
std::vector<exec_aten::DimOrderType> dim_order;
36-
std::vector<exec_aten::StridesType> strides;
33+
// A custom deleter of the std::shared_ptr is required to be copyable until
34+
// C++20, so any data it holds must be copyable too. Hence, we use shared_ptr
35+
// to hold the data and metadata to avoid unnecessary copies.
36+
std::shared_ptr<void> data;
37+
std::shared_ptr<std::vector<exec_aten::SizesType>> sizes;
38+
std::shared_ptr<std::vector<exec_aten::DimOrderType>> dim_order;
39+
std::shared_ptr<std::vector<exec_aten::StridesType>> strides;
3740

3841
void operator()(exec_aten::TensorImpl* pointer) {
3942
// Release all resources immediately since the data held by the
40-
// TensorImplDeleter is tied to the managed object, not the smart pointer
43+
// TensorImplPtrDeleter is tied to the managed object, not the smart pointer
4144
// itself. We need to free this memory when the object is destroyed, not
4245
// when the smart pointer (and deleter) are eventually destroyed or reset.
4346
data.reset();
44-
sizes = {};
45-
dim_order = {};
46-
strides = {};
47+
sizes.reset();
48+
dim_order.reset();
49+
strides.reset();
4750
delete pointer;
4851
}
4952
};
@@ -90,11 +93,13 @@ TensorImplPtr make_tensor_impl_ptr(
9093
return TensorImplPtr(
9194
tensor_impl.release(),
9295
TensorImplPtrDeleter{
93-
std::unique_ptr<void, std::function<void(void*)>>(
96+
std::shared_ptr<void>(
9497
data, deleter ? std::move(deleter) : noop_deleter),
95-
std::move(sizes),
96-
std::move(dim_order),
97-
std::move(strides)});
98+
std::make_shared<std::vector<exec_aten::SizesType>>(std::move(sizes)),
99+
std::make_shared<std::vector<exec_aten::DimOrderType>>(
100+
std::move(dim_order)),
101+
std::make_shared<std::vector<exec_aten::StridesType>>(
102+
std::move(strides))});
98103
#else
99104
auto options = c10::TensorOptions()
100105
.dtype(c10::scalarTypeToTypeMeta(type))

extension/tensor/tensor_impl_ptr.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,15 +94,18 @@ TensorImplPtr make_tensor_impl_ptr(
9494
std::vector<exec_aten::StridesType> strides = {},
9595
exec_aten::TensorShapeDynamism dynamism =
9696
exec_aten::TensorShapeDynamism::STATIC) {
97-
const auto data_ptr = data.data();
97+
auto raw_data_ptr = data.data();
98+
auto data_ptr = std::make_shared<
99+
std::vector<typename runtime::ScalarTypeToCppType<T>::type>>(
100+
std::move(data));
98101
return make_tensor_impl_ptr(
99102
T,
100103
std::move(sizes),
101-
data_ptr,
104+
raw_data_ptr,
102105
std::move(dim_order),
103106
std::move(strides),
104107
dynamism,
105-
[data = std::move(data)](void*) {});
108+
[data_ptr = std::move(data_ptr)](void*) {});
106109
}
107110

108111
} // namespace extension

extension/tensor/test/tensor_impl_ptr_test.cpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,3 +181,46 @@ TEST_F(TensorImplPtrTest, TensorImplOwningEmptyData) {
181181
EXPECT_EQ(tensor_impl->strides()[1], 1);
182182
EXPECT_EQ(tensor_impl->data(), nullptr);
183183
}
184+
185+
TEST_F(TensorImplPtrTest, SharedDataManagement) {
186+
auto data = std::make_shared<std::vector<float>>(100, 1.0f);
187+
auto tensor_impl1 = make_tensor_impl_ptr(
188+
exec_aten::ScalarType::Float, {10, 10}, data->data());
189+
auto tensor_impl2 = tensor_impl1;
190+
191+
EXPECT_EQ(tensor_impl1.get(), tensor_impl2.get());
192+
EXPECT_EQ(tensor_impl1.use_count(), 2);
193+
EXPECT_EQ(((float*)tensor_impl1->data())[0], 1.0f);
194+
195+
((float*)tensor_impl1->mutable_data())[0] = 2.0f;
196+
EXPECT_EQ(((float*)tensor_impl2->data())[0], 2.0f);
197+
198+
tensor_impl1.reset();
199+
EXPECT_NE(tensor_impl2.get(), nullptr);
200+
EXPECT_EQ(tensor_impl2.use_count(), 1);
201+
202+
EXPECT_EQ(((float*)tensor_impl2->data())[0], 2.0f);
203+
}
204+
205+
TEST_F(TensorImplPtrTest, CustomDeleterWithSharedData) {
206+
auto data = std::make_shared<std::vector<float>>(100, 1.0f);
207+
bool deleter_called = false;
208+
{
209+
auto tensor_impl = make_tensor_impl_ptr(
210+
exec_aten::ScalarType::Float,
211+
{10, 10},
212+
data->data(),
213+
{},
214+
{},
215+
exec_aten::TensorShapeDynamism::STATIC,
216+
[data, &deleter_called](void*) mutable {
217+
deleter_called = true;
218+
data.reset();
219+
});
220+
221+
EXPECT_EQ(data.use_count(), 2);
222+
EXPECT_FALSE(deleter_called);
223+
}
224+
EXPECT_TRUE(deleter_called);
225+
EXPECT_EQ(data.use_count(), 1);
226+
}

extension/tensor/test/tensor_ptr_test.cpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,3 +176,36 @@ TEST_F(TensorPtrTest, TensorOwningEmptyData) {
176176
EXPECT_EQ(tensor->strides()[1], 1);
177177
EXPECT_EQ(tensor->data_ptr<float>(), nullptr);
178178
}
179+
180+
TEST_F(TensorPtrTest, TensorSharingImplModifiesSharedDataVector) {
181+
std::vector<float> data = {1, 2, 3, 4, 5, 6};
182+
183+
auto tensor1 = make_tensor_ptr({2, 3}, std::move(data));
184+
auto tensor2 = make_tensor_ptr(tensor1);
185+
186+
tensor1->mutable_data_ptr<float>()[0] = 10;
187+
EXPECT_EQ(tensor2->const_data_ptr<float>()[0], 10);
188+
189+
tensor2->mutable_data_ptr<float>()[5] = 20;
190+
EXPECT_EQ(tensor1->const_data_ptr<float>()[5], 20);
191+
}
192+
193+
TEST_F(TensorPtrTest, TensorSharingImplResizingAffectsBothVector) {
194+
std::vector<float> data = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
195+
196+
auto tensor1 = make_tensor_ptr(
197+
{3, 4},
198+
std::move(data),
199+
{},
200+
{},
201+
exec_aten::TensorShapeDynamism::DYNAMIC_UNBOUND);
202+
auto tensor2 = make_tensor_ptr(tensor1);
203+
204+
EXPECT_EQ(resize_tensor_ptr(tensor1, {2, 6}), Error::Ok);
205+
EXPECT_EQ(tensor2->size(0), 2);
206+
EXPECT_EQ(tensor2->size(1), 6);
207+
208+
EXPECT_EQ(resize_tensor_ptr(tensor2, {4, 3}), Error::Ok);
209+
EXPECT_EQ(tensor1->size(0), 4);
210+
EXPECT_EQ(tensor1->size(1), 3);
211+
}

0 commit comments

Comments
 (0)