Skip to content

Commit e793795

Browse files
authored
[ET-VK] Add TmpTensorVRef struct to recycle temporary tensor memory
Differential Revision: D62144398 Pull Request resolved: #5041
1 parent 6ec5342 commit e793795

File tree

4 files changed

+251
-0
lines changed

4 files changed

+251
-0
lines changed

backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,72 @@ VALUE_PTR_CLASS_IMPL(SymIntPtr, SymInt, SymInt)
4747

4848
#undef VALUE_PTR_CLASS_IMPL
4949

50+
//
51+
// TmpTensor
52+
//
53+
54+
TmpTensor::TmpTensor(
55+
ComputeGraph* const graph_ptr,
56+
const std::vector<int64_t>& sizes,
57+
const vkapi::ScalarType dtype,
58+
const utils::StorageType storage_type,
59+
const utils::GPUMemoryLayout memory_layout)
60+
: graph_p(graph_ptr),
61+
sobj_idx(get_sobj_idx()),
62+
vref(graph_p->add_tensor(
63+
sizes,
64+
dtype,
65+
storage_type,
66+
memory_layout,
67+
sobj_idx)) {}
68+
69+
TmpTensor::TmpTensor(
70+
ComputeGraph* const graph_ptr,
71+
const std::vector<int64_t>& sizes,
72+
const vkapi::ScalarType dtype,
73+
const utils::StorageType storage_type)
74+
: graph_p(graph_ptr),
75+
sobj_idx(get_sobj_idx()),
76+
vref(graph_p->add_tensor(sizes, dtype, storage_type, sobj_idx)) {}
77+
78+
TmpTensor::TmpTensor(
79+
ComputeGraph* const graph_ptr,
80+
const std::vector<int64_t>& sizes,
81+
const vkapi::ScalarType dtype,
82+
const utils::GPUMemoryLayout memory_layout)
83+
: graph_p(graph_ptr),
84+
sobj_idx(get_sobj_idx()),
85+
vref(graph_p->add_tensor(sizes, dtype, memory_layout, sobj_idx)) {}
86+
87+
TmpTensor::TmpTensor(
88+
ComputeGraph* const graph_ptr,
89+
const std::vector<int64_t>& sizes,
90+
const vkapi::ScalarType dtype)
91+
: graph_p(graph_ptr),
92+
sobj_idx(get_sobj_idx()),
93+
vref(graph_p->add_tensor(sizes, dtype, sobj_idx)) {}
94+
95+
// The temporary tensor has reached the end of its lifetime: hand its shared
// object index back to the graph's pool so that a later `TmpTensor` can reuse
// it. A negative index is treated as invalid and is never returned to the pool.
TmpTensor::~TmpTensor() {
  const bool has_valid_sobj = sobj_idx >= 0;
  if (!has_valid_sobj) {
    return;
  }
  graph_p->tmp_shared_object_idxs_.push(sobj_idx);
}
102+
103+
// Picks the shared object index to use for a new temporary tensor.
//
// When the graph's pool of recycled indices is empty, the returned index is
// the current number of shared objects, i.e. one past the last existing
// shared object (presumably `add_tensor` creates the shared object for that
// index -- confirm against `ComputeGraph`). Otherwise the index on top of the
// pool is removed from the pool and returned.
int64_t TmpTensor::get_sobj_idx() {
  auto& free_idxs = graph_p->tmp_shared_object_idxs_;

  // Nothing to recycle: ask for a brand new shared object index
  if (free_idxs.empty()) {
    return static_cast<int64_t>(graph_p->shared_objects_.size());
  }

  // Reuse the most recently returned shared object index
  const int64_t recycled_idx = free_idxs.top();
  free_idxs.pop();
  return recycled_idx;
}
115+
50116
//
51117
// ComputeGraph
52118
//

backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
// @lint-ignore-every CLANGTIDY facebook-hte-BadMemberName
1212

1313
#include <optional>
14+
#include <stack>
1415

1516
#include <executorch/backends/vulkan/runtime/api/api.h>
1617

@@ -67,6 +68,79 @@ DECL_VALUE_PTR_CLASS(SymIntPtr, SymInt);
6768

6869
#undef DECL_VALUE_PTR_CLASS
6970

71+
//
72+
// TmpTensor
73+
//
74+
75+
/*
76+
* This struct is used to recycle the memory of temporary tensors that are
77+
* created during the execution of a node. Upon construction, this struct will
78+
* check the `tmp_shared_object_idxs_` of the provided `ComputeGraph` instance
79+
* if any shared objects are available; if not, then a new one is created. A
80+
* tensor value is then added to the `ComputeGraph` instance with the requested
81+
* specifications. Upon destruction, the shared object index of the temporary
82+
* tensor is returned to `tmp_shared_object_idxs_`.
83+
*
84+
* Note that instances of this struct can be used as if they were `ValueRef` due
85+
* to implementation of a custom casting operator.
86+
*
87+
* This class should only be used to create tensors whose lifetimes exist only
88+
* in a well defined scope (i.e. within a function).
89+
*/
90+
struct TmpTensor {
91+
ComputeGraph* graph_p;
92+
int64_t sobj_idx;
93+
ValueRef vref;
94+
95+
//
96+
// Match all available overloads of `add_tensor`
97+
//
98+
99+
TmpTensor(
100+
ComputeGraph* const graph_ptr,
101+
const std::vector<int64_t>& sizes,
102+
const vkapi::ScalarType dtype,
103+
const utils::StorageType storage_type,
104+
const utils::GPUMemoryLayout memory_layout);
105+
106+
TmpTensor(
107+
ComputeGraph* const graph_ptr,
108+
const std::vector<int64_t>& sizes,
109+
const vkapi::ScalarType dtype,
110+
const utils::StorageType storage_type);
111+
112+
TmpTensor(
113+
ComputeGraph* const graph_ptr,
114+
const std::vector<int64_t>& sizes,
115+
const vkapi::ScalarType dtype,
116+
const utils::GPUMemoryLayout memory_layout);
117+
118+
TmpTensor(
119+
ComputeGraph* const graph_ptr,
120+
const std::vector<int64_t>& sizes,
121+
const vkapi::ScalarType dtype);
122+
123+
// No copy construction or assignment
124+
TmpTensor(TmpTensor& other) = delete;
125+
TmpTensor& operator=(TmpTensor& other) = delete;
126+
127+
// No move construction or assignment
128+
TmpTensor(TmpTensor&& other) = delete;
129+
TmpTensor& operator=(TmpTensor&& other) = delete;
130+
131+
// Custom cast to ValueRef
132+
operator ValueRef() const {
133+
return vref;
134+
};
135+
136+
~TmpTensor();
137+
138+
private:
139+
// Helper function to get first available shared object index or request a new
140+
// one to be created.
141+
int64_t get_sobj_idx();
142+
};
143+
70144
//
71145
// ComputeGraph
72146
//
@@ -94,7 +168,12 @@ class ComputeGraph final {
94168
vkapi::DescriptorPoolConfig execute_descriptor_counts_;
95169

96170
std::unique_ptr<api::Context> context_;
171+
97172
std::vector<SharedObject> shared_objects_;
173+
// This stack is used by `TmpTensor` instances to recycle shared objects
174+
// for temporary tensors. See the comments of `TmpTensor` for more details
175+
std::stack<int64_t> tmp_shared_object_idxs_;
176+
98177
std::vector<Value> values_;
99178
std::vector<api::ParamsBuffer> param_ubos_;
100179

@@ -593,6 +672,8 @@ class ComputeGraph final {
593672
friend class BoolListPtr;
594673
friend class ValueListPtr;
595674
friend class SymIntPtr;
675+
676+
friend struct TmpTensor;
596677
};
597678

598679
template <typename T>

backends/vulkan/runtime/graph/containers/Value.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,11 @@ inline bool is_valid(ValueRef value_ref) {
2929
struct IOValueRef {
3030
ValueRef value;
3131
ValueRef staging;
32+
33+
// Custom cast to ValueRef
34+
operator ValueRef() const {
35+
return value;
36+
};
3237
};
3338

3439
/*

backends/vulkan/test/vulkan_compute_api_test.cpp

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1518,6 +1518,105 @@ TEST(VulkanComputeGraphTest, test_simple_shared_objects_with_resize) {
15181518
}
15191519
}
15201520

1521+
// Builds a small graph whose intermediate results live in scoped `TmpTensor`
// instances, checks which shared object index each temporary receives (indices
// released at the end of a scope are expected to be handed out again in the
// next scope), then executes the graph and verifies the output values.
TEST(VulkanComputeGraphTest, test_simple_graph_with_tmp_tensors) {
  GraphConfig config;
  ComputeGraph graph(config);

  std::vector<int64_t> size_big = {8, 64, 124};
  std::vector<int64_t> size_small = {8, 1, 124};

  // Build graph

  IOValueRef a = graph.add_input_tensor(
      size_big, vkapi::kFloat, /*shared_object_idx = */ 0);
  IOValueRef b = graph.add_input_tensor(
      size_small, vkapi::kFloat, /*shared_object_idx = */ 1);

  IOValueRef out = {};

  out.value =
      graph.add_tensor(size_big, vkapi::kFloat, /*shared_object_idx = */ 2);

  // Perform the following compute
  //
  // a, b, out;
  // {
  //   inter;
  //   {
  //     tmp = a + b
  //     tmp2 = tmp + a
  //     inter = tmp2 + b
  //   }
  //   {
  //     tmp = inter + b;
  //     tmp2 = tmp + a
  //     out = tmp2 + b;
  //   }
  // }
  {
    // Shared objects 0, 1 and 2 are taken by a, b and out, so the first
    // temporary is expected to receive index 3.
    TmpTensor inter(&graph, size_big, vkapi::kFloat);
    EXPECT_EQ(inter.sobj_idx, 3);
    {
      TmpTensor tmp(&graph, size_big, vkapi::kFloat);
      EXPECT_EQ(tmp.sobj_idx, 4);
      VK_GET_OP_FN("aten.add.Tensor")
      (graph, {a, b, kDummyValueRef, tmp});

      TmpTensor tmp2(&graph, size_big, vkapi::kFloat);
      EXPECT_EQ(tmp2.sobj_idx, 5);
      VK_GET_OP_FN("aten.add.Tensor")
      (graph, {tmp, a, kDummyValueRef, tmp2});

      VK_GET_OP_FN("aten.add.Tensor")
      (graph, {tmp2, b, kDummyValueRef, inter});
    }
    {
      // The previous scope released indices 4 and 5; they must be reused here
      // instead of new shared objects being requested.
      TmpTensor tmp(&graph, size_big, vkapi::kFloat);
      EXPECT_EQ(tmp.sobj_idx, 4);
      VK_GET_OP_FN("aten.add.Tensor")
      (graph, {inter, b, kDummyValueRef, tmp});

      TmpTensor tmp2(&graph, size_big, vkapi::kFloat);
      EXPECT_EQ(tmp2.sobj_idx, 5);
      VK_GET_OP_FN("aten.add.Tensor")
      (graph, {tmp, a, kDummyValueRef, tmp2});

      VK_GET_OP_FN("aten.add.Tensor")
      (graph, {tmp2, b, kDummyValueRef, out});
    }
  }

  out.staging = graph.set_output_tensor(out.value);

  graph.prepare();
  graph.encode_execute();

  // Run graph

  for (float base = 5.0f; base < 30.0f; base += 10.0f) {
    // Reference computation mirroring the chain of additions built above
    float val_a = base + 2.0f;
    float val_b = base + 1.5f;
    float val_tmp = val_a + val_b;
    float val_tmp2 = val_tmp + val_a;
    float val_inter = val_tmp2 + val_b;
    float val_tmp_2 = val_inter + val_b;
    float val_tmp2_2 = val_tmp_2 + val_a;
    float val_out = val_tmp2_2 + val_b;

    fill_vtensor(graph, a, val_a);
    fill_vtensor(graph, b, val_b);

    graph.execute();

    EXTRACT_TENSOR(out);

    // Sanity check that the values are correct
    for (size_t idx = 0; idx < graph.get_tensor(out.value)->numel(); ++idx) {
      CHECK_VALUE(data_out, idx, val_out);
    }
  }
}
1619+
15211620
TEST(VulkanComputeGraphTest, test_large_graph) {
15221621
auto build_start_time = std::chrono::system_clock::now();
15231622
GraphConfig config;

0 commit comments

Comments
 (0)